In [None]:
import torch
import numpy as np
from torch import optim, nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from torchvision.models.segmentation import fcn_resnet50

from data.pascal_voc_dataset import PascalVOCSegmentation
from data.utils import get_pascal_dataloader
from data.pascal_data_loader import PascalVOCLoader

In [None]:
import os  # local import: only this cell needs it

# Restrict this kernel to GPUs 0 and 1.
# The previous `! setenv ...` ran in a short-lived subshell (and `setenv` is a
# csh builtin), so it never changed the environment of the Python process that
# actually talks to CUDA. Setting os.environ affects this process directly;
# it must happen before the first CUDA call made further down the notebook.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

In [None]:
# --- Optimisation hyper-parameters -----------------------------------------
LEARNING_RATE = 0.0025   # initial SGD learning rate
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005
POWER = 0.9              # exponent of the polynomial learning-rate decay

# --- Training schedule ------------------------------------------------------
BATCH_SIZE = 32
TOTAL_ITERATIONS = 40_000  # training stops after this many optimizer steps
MAX_EPOCHS = 100           # upper bound on passes over the training loader

# --- Data loading -----------------------------------------------------------
DATA_ROOT = "."
NUM_WORKERS = 0

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# NOTE(review): machine-specific absolute path; the notebook will not run on
# another host without editing it. Consider deriving it from DATA_ROOT.
pascal_root = "/local/home/sanjee23/dev/ReCo/dataset/pascal"

# Build the project loader and request only the labelled splits.
loader = PascalVOCLoader(data_path=pascal_root)
train_loader, val_loader = loader.create_loaders(use_unlabeled=False)

In [None]:
# Sanity check: display the length of the training loader
# (presumably the number of batches per epoch — confirm against PascalVOCLoader).
len(train_loader)

In [None]:
# Instantiate FCN-ResNet50 with pretrained weights and move it to the
# selected device.
model = fcn_resnet50(pretrained=True)
model = model.to(device)

# Plain SGD with momentum and weight decay; the learning rate set here is the
# starting value and is overwritten by the schedule during training.
sgd_settings = dict(
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)
optimizer = optim.SGD(model.parameters(), **sgd_settings)

# Targets equal to -1 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-1)

In [None]:
def compute_iou(outputs, targets, num_classes=21, ignore_index=-1):
    """Mean intersection-over-union of the argmax prediction against ``targets``.

    Args:
        outputs: Tensor of per-class scores; the predicted class of each
            element is the argmax over dim 1.
        targets: Integer label tensor compared element-wise with that argmax,
            so it must have a compatible shape.
        num_classes: Classes ``1 .. num_classes - 1`` are scored. Class 0 is
            skipped (presumably background — confirm this is intended).
        ignore_index: Label value marking elements that must not be scored.
            The default of -1 matches the ``ignore_index`` of the loss used
            in this notebook.

    Returns:
        float: IoU averaged over the classes that occur in the prediction or
        the target (non-empty union); ``0.0`` when no such class exists.
    """
    smooth = 1e-6
    preds = torch.argmax(outputs, dim=1)

    # Elements carrying the ignore label have no ground truth. Without this
    # mask, any foreground prediction on them is counted as a false positive
    # and inflates the union.
    valid = targets != ignore_index

    ious = []
    for cls in range(1, num_classes):
        pred_inds = (preds == cls) & valid
        target_inds = (targets == cls) & valid

        # Integer counts avoid float32 rounding on very large masks.
        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()

        # A class absent from both prediction and target says nothing about
        # quality, so it is left out of the mean.
        if union > 0:
            ious.append((intersection + smooth) / (union + smooth))

    return float(np.mean(ious)) if ious else 0.0

In [None]:
def adjust_learning_rate(optimizer, initial_lr, iter, total_iter, power=0.9):
    """Apply polynomial learning-rate decay to every parameter group.

    The rate is ``initial_lr * (1 - iter / total_iter) ** power``.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts whose
            ``'lr'`` entry is overwritten.
        initial_lr: Learning rate at ``iter == 0``.
        iter: Current iteration index.
        total_iter: Iteration at which the rate reaches zero.
        power: Exponent of the decay curve.

    Returns:
        The learning rate that was written into the optimizer.
    """
    # Clamp at zero: once iter exceeds total_iter the base turns negative, and
    # a negative float raised to a fractional power is a complex number in
    # Python, which must never reach the optimizer.
    remaining = max(0.0, 1 - iter / total_iter)
    lr = initial_lr * remaining ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

In [None]:
import os  # local import: used only to create the checkpoint directory

val_interval = 2000          # run validation every N training iterations
checkpoint_interval = 10000  # write a resumable checkpoint every N iterations
checkpoint_dir = "checkpoints"
best_iou = 0
total_iterations = 0

# torch.save does not create missing parent directories; without this the
# first save would fail only after thousands of training iterations.
os.makedirs(checkpoint_dir, exist_ok=True)


def _run_validation(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` and return ``(mean loss, mean IoU)``.

    Both values are averaged per batch. The IoU is therefore the mean of the
    per-batch ``compute_iou`` scores, not a dataset-level IoU.
    Switches the model to eval mode; the caller restores train mode.
    """
    model.eval()
    running_loss = 0.0
    running_iou = 0.0
    num_batches = 0

    with torch.no_grad():
        for val_batch in val_loader:
            val_img = val_batch[0].float().to(device)
            val_mask = val_batch[1].long().to(device)

            val_pred = model(val_img)['out']
            running_loss += criterion(val_pred, val_mask).item()
            running_iou += compute_iou(val_pred, val_mask)
            num_batches += 1

    if num_batches == 0:
        # Nothing to evaluate: report NaN loss and an IoU that can never beat
        # the current best, rather than failing on an undefined counter.
        return float("nan"), 0.0
    return running_loss / num_batches, running_iou / num_batches


# The context manager closes the progress bar even if training raises.
with tqdm(total=TOTAL_ITERATIONS) as pbar:
    for epoch in range(MAX_EPOCHS):
        print(f"Epoch {epoch+1}/{MAX_EPOCHS}")

        for img_mask in train_loader:
            # Validation leaves the model in eval mode, so re-enable train
            # mode at the start of every step.
            model.train()

            img = img_mask[0].float().to(device)
            mask = img_mask[1].long().to(device)

            # The schedule is driven by the global iteration count, not the epoch.
            current_lr = adjust_learning_rate(optimizer, LEARNING_RATE, total_iterations, TOTAL_ITERATIONS, POWER)

            optimizer.zero_grad()
            y_pred = model(img)['out']
            loss = criterion(y_pred, mask)
            loss.backward()
            optimizer.step()

            # Keep the training loss as a plain float so that validation
            # cannot overwrite it before it is logged or checkpointed.
            train_loss = loss.item()

            total_iterations += 1
            pbar.update(1)
            pbar.set_description(f"Epoch: {epoch+1}/{MAX_EPOCHS}, Iter: {total_iterations}/{TOTAL_ITERATIONS}, Loss: {train_loss:.4f}, LR: {current_lr:.6f}")

            if total_iterations % val_interval == 0:
                val_loss, mean_iou = _run_validation(model, val_loader, criterion, device)

                print("-"*50)
                print(f"Epoch: {epoch+1}/{MAX_EPOCHS}, Iteration: {total_iterations}/{TOTAL_ITERATIONS}")
                print(f"Validation Loss: {val_loss:.4f}, Mean IoU: {mean_iou:.4f}")
                print("-"*50)

                # Keep only the weights with the best validation IoU so far.
                if mean_iou > best_iou:
                    best_iou = mean_iou
                    torch.save(model.state_dict(), f"{checkpoint_dir}/best_model.pth")

            if total_iterations % checkpoint_interval == 0:
                # Resumable snapshot. 'loss' is the training loss of this
                # iteration, stored as a float (previously the last
                # validation batch's loss tensor ended up here).
                torch.save({
                    'epoch': epoch + 1,
                    'iteration': total_iterations,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss,
                }, f"{checkpoint_dir}/checkpoint_iter_{total_iterations}.pth")

            if total_iterations >= TOTAL_ITERATIONS:
                break

        # Leave the epoch loop too once the iteration budget is spent.
        if total_iterations >= TOTAL_ITERATIONS:
            print(f"Reached {TOTAL_ITERATIONS} iterations. Stopping training.")
            break

print(f"Training completed! Best validation IoU: {best_iou:.4f}")
torch.save(model.state_dict(), f"{checkpoint_dir}/final_model.pth")