In [1]:
import torch
import numpy as np
from torch import optim, nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
from torchvision.models.segmentation import fcn_resnet50

from data.pascal_data_loader import PascalVOCLoader
from network.mean_ts import TeacherModel
from train_utils import  *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
LEARNING_RATE = 2.5e-3 
BATCH_SIZE = 32 
MOMENTUM = 0.9 
WEIGHT_DECAY = 5e-4 
POWER = 0.9 
TOTAL_ITERATIONS = 40000
MAX_EPOCHS = 100
EMA_DECAY = 0.99
CONF_THRESH = 0.95
UNSUP_WEIGHT = 0.5


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
student_model = fcn_resnet50(pretrained=True).to(device)

teacher_model = TeacherModel(student_model, ema_decay=EMA_DECAY).to(device)

optimizer = optim.SGD(student_model.parameters(), 
                     lr=LEARNING_RATE, 
                     momentum=MOMENTUM, 
                     weight_decay=WEIGHT_DECAY)
criterion = nn.CrossEntropyLoss(ignore_index=-1)

In [4]:
loader = PascalVOCLoader(
    data_path="/local/home/sanjee23/dev/ReCo/dataset/pascal")
train_labeled_loader, train_unlabeled_loader, val_loader = loader.create_loaders(use_unlabeled=True)

In [6]:
val_interval = 2000
best_iou = 0
total_iterations = 0

pbar = tqdm(total=TOTAL_ITERATIONS)

for epoch in range(MAX_EPOCHS):
    print(f"Epoch {epoch+1}/{MAX_EPOCHS}")
    
    labeled_iter = iter(train_labeled_loader)
    unlabeled_iter = iter(train_unlabeled_loader)
    
    for labeled_batch in train_labeled_loader:
        try:
            unlabeled_batch = next(unlabeled_iter)
        except StopIteration:
            unlabeled_iter = iter(train_unlabeled_loader)
            unlabeled_batch = next(unlabeled_iter)
        
        labeled_img = labeled_batch[0].float().to(device)
        labeled_mask = labeled_batch[1].long().to(device)
        unlabeled_img = unlabeled_batch[0].float().to(device)
        
        current_lr = adjust_learning_rate(optimizer, LEARNING_RATE, total_iterations, TOTAL_ITERATIONS, POWER)
        
        student_model.train()
        student_labeled_output = student_model(labeled_img)['out']
        
        supervised_loss = criterion(student_labeled_output, labeled_mask)
        
        pseudo_labels, conf_mask, confidence = teacher_model.generate_pseudo_labels(
            unlabeled_img, confidence_threshold=CONF_THRESH)
        
        student_unlabeled_output = student_model(unlabeled_img)['out']
        
        unsupervised_loss = calculate_unsupervised_loss(
            student_unlabeled_output, pseudo_labels, conf_mask)
        
        conf_ratio = conf_mask.float().mean().item()
        
        eta = conf_ratio if conf_ratio > 0 else 0.1
        total_loss = supervised_loss + UNSUP_WEIGHT * eta * unsupervised_loss
        
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        teacher_model.update_weights(student_model)
        
        total_iterations += 1
        pbar.update(1)
        pbar.set_description(
            f"Epoch: {epoch+1}/{MAX_EPOCHS}, Iter: {total_iterations}/{TOTAL_ITERATIONS}, "
            f"Loss: {total_loss.item():.4f}, Sup: {supervised_loss.item():.4f}, "
            f"Unsup: {unsupervised_loss.item():.4f}, LR: {current_lr:.6f}, Conf: {conf_ratio:.2f}"
        )
        
        if total_iterations % val_interval == 0:
            student_model.eval()
            val_running_loss = 0
            val_iou = 0
            
            with torch.no_grad():
                for idx, img_mask in enumerate(val_loader):
                    img = img_mask[0].float().to(device)
                    mask = img_mask[1].long().to(device)
                    
                    y_pred = student_model(img)['out']
                    loss = criterion(y_pred, mask)
                    
                    batch_iou = compute_iou(y_pred, mask)
                    val_iou += batch_iou
                    val_running_loss += loss.item()
                
                val_loss = val_running_loss / (idx + 1)
                mean_iou = val_iou / (idx + 1)
            
            print("-"*50)
            print(f"Epoch: {epoch+1}/{MAX_EPOCHS}, Iteration: {total_iterations}/{TOTAL_ITERATIONS}")
            print(f"Validation Loss: {val_loss:.4f}, Mean IoU: {mean_iou:.4f}")
            print("-"*50)
            
            if mean_iou > best_iou:
                best_iou = mean_iou
                torch.save(student_model.state_dict(), "checkpoints/best_student_model.pth")
        
        if total_iterations % 10000 == 0:
            torch.save({
                'epoch': epoch + 1,
                'iteration': total_iterations,
                'student_model_state_dict': student_model.state_dict(),
                'teacher_model_state_dict': teacher_model.model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': total_loss,
            }, f"checkpoints/checkpoint_iter_{total_iterations}.pth")
        
        if total_iterations >= TOTAL_ITERATIONS:
            break
    
    if total_iterations >= TOTAL_ITERATIONS:
        print(f"Reached {TOTAL_ITERATIONS} iterations. Stopping training.")
        break

pbar.close()
print(f"Training completed! Best validation IoU: {best_iou:.4f}")
torch.save(student_model.state_dict(), "checkpoints/final_student_model.pth")

Epoch: 1/100, Iter: 2/40000, Loss: 0.3147, Sup: 0.3070, Unsup: 0.0225, LR: 0.002500, Conf: 0.68:   0%|          | 2/40000 [00:49<275:50:52, 24.83s/it]

Epoch 1/100





Epoch 2/100




Epoch 3/100




Epoch 4/100




Epoch 5/100




Epoch 6/100




Epoch 7/100




Epoch 8/100




Epoch 9/100




Epoch 10/100




--------------------------------------------------
Epoch: 10/100, Iteration: 2000/40000
Validation Loss: 0.8028, Mean IoU: 0.1962
--------------------------------------------------
Epoch 11/100




Epoch 12/100




Epoch 13/100




Epoch 14/100




Epoch 15/100




Epoch 16/100




Epoch 17/100




Epoch 18/100




Epoch 19/100




Epoch 20/100




--------------------------------------------------
Epoch: 20/100, Iteration: 4000/40000
Validation Loss: 0.9212, Mean IoU: 0.1648
--------------------------------------------------
Epoch 21/100




Epoch 22/100




KeyboardInterrupt: 