In [1]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from voc import get_dataloader
from main_utils import set_seed
from model_factory import get_model
from ema import build_ema, EMA

set_seed(42)

In [2]:
# SIZE = (256, 256)
import os
SIZE = (160, 160)
CONFIDENCE_THRESHOLD = 0.7
BATCH_SIZE = 4
CHECKPOINT_DIR = "./checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def scale_to_01(image, **kwargs):
    return image.astype('float32') / 255.0

train_labeled_transforms = A.Compose([
    A.Resize(SIZE[0], SIZE[1]),         
    A.HorizontalFlip(p=0.5),
    A.Lambda(image=scale_to_01), 
    ToTensorV2(),
], bbox_params=A.BboxParams(format='pascal_voc'))

train_unlabeled_transforms = A.Compose(
        [
            A.Resize(SIZE[0], SIZE[1]),
            A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
            A.GaussianBlur(blur_limit=(3, 7), sigma_limit=(0.1, 2.0), p=0.5),
            A.CoarseDropout(num_holes_range=(3, 3), hole_height_range=(0.05, 0.1),
                             hole_width_range=(0.05, 0.1), p=0.5),
            A.Lambda(image=scale_to_01), 
            ToTensorV2(),
        ],
        bbox_params=A.BboxParams(format='pascal_voc')
    )

test_transforms = A.Compose([
    A.Resize(SIZE[0], SIZE[1]),
    A.Lambda(image=scale_to_01), 
    ToTensorV2(), 
], bbox_params=A.BboxParams(format='pascal_voc'))



dt_train_labeled = get_dataloader("trainval", "2007", BATCH_SIZE, transform=train_labeled_transforms)
dt_train_unlabeled_weakaug = get_dataloader("trainval", "2012", BATCH_SIZE, transform=train_labeled_transforms) 
dt_train_unlabeled_strongaug = get_dataloader("trainval", "2012", BATCH_SIZE, transform=train_unlabeled_transforms)
dt_test = get_dataloader("test", "2007", BATCH_SIZE, transform=test_transforms, shuffle=False)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
import matplotlib.pyplot as plt
METRIC_KEYS = ["loss_classifier", "loss_box_reg", "loss_objectness", "loss_rpn_box_reg", "total"]

def plot_losses(history):
    epochs = range(1, len(history["total"]) + 1)

    plt.figure(figsize=(7, 5))
    for comp in METRIC_KEYS:
        plt.plot(epochs, history[comp], label=f"Train {comp}", linewidth=2)
    plt.title(f"Train results over epochs")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True, linestyle="--", alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()


In [4]:
from tqdm import tqdm
import os

def load_checkpoint(checkpoint_path, optimizer=None, device='cuda'):
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    model = get_model(device=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model weights loaded from {checkpoint_path}")

    if optimizer and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Optimizer state loaded from {checkpoint_path}")

    epoch = checkpoint.get('epoch', 0)
    print(f"Resuming from epoch {epoch}")
    return model, optimizer, epoch

def train(model, optimizer, dt_train_labeled, device):
    model.train()
    train_batches = 0
    history = {key : 0 for key in METRIC_KEYS}

    for images, targets in tqdm(dt_train_labeled, desc="Training"):
        # if train_batches == 5: break
        for target in targets:
            target["boxes"] = target["boxes"].to(device)
            target["labels"] = target["labels"].to(device)
        images = images.to(device)
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        for k, v in loss_dict.items():
            history[k] += v.item()

        history["total"] += loss.item()
        train_batches += 1
    for key in history:
        history[key] = history[key] / train_batches
    return history


def save_checkpoint(model, optimizer, epoch, path):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)
    print(f"Checkpoint saved at {path}")


def pipeline(epochs, dt_train_labeled, device, checkpoint_every):

    model = get_model(device=device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    history = {key : [] for key in METRIC_KEYS}

    for epoch in range(epochs):
        print(f"\n==================== Epoch {epoch+1}/{epochs} ====================\n")
        train_history = train(model, optimizer, dt_train_labeled, device)
        lr_scheduler.step(train_history["total"])
        for key, val in train_history.items():
            history[key].append(val)
        plot_losses(history)
        if (epoch + 1) % checkpoint_every == 0 or (epoch + 1) == epochs:
            checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_epoch_{epoch+1}.pth")
            save_checkpoint(model, optimizer, epoch + 1, checkpoint_path)

# pipeline(30, dt_train_labeled, device, 3)

In [None]:
images, labels = next(iter(dt_train_unlabeled_weakaug))
from torchvision.ops import batched_nms
NMS_IOU = 0.5
def generate_pseudo_labels(model : torch.nn.Module, images : torch.Tensor, device):
    model.eval()
    with torch.no_grad():
        images = images.to(device)
        outputs = model(images, None)
        for output in outputs:
            boxes  = output["boxes"]
            labels = output["labels"]
            scores = output["scores"]

            keep_nms = batched_nms(
                boxes, scores, labels,
                iou_threshold=NMS_IOU
            )
            boxes  = boxes[keep_nms]
            labels = labels[keep_nms]
            scores = scores[keep_nms]

            boxes_to_keep = scores > CONFIDENCE_THRESHOLD        
            boxes  = boxes[boxes_to_keep]
            labels = labels[boxes_to_keep]
            scores = scores[boxes_to_keep]

            output["boxes"]  = boxes
            output["labels"] = labels
            output["scores"] = scores
        return outputs       
    
# model, optimizer, epoch = load_checkpoint(checkpoint_path=checkpoint_path, optimizer=None, device=device)
# generate_pseudo_labels(model, images, device)

In [7]:
def train_self_supervised_one_epoch(teacher : EMA, student, optimizer, dt_weak, dt_strong):
    student.train()
    train_batches = 0
    history = {key : 0 for key in ["loss_classifier", "loss_objectness", "total"]}
    for (img_weak, _), (img_strong, _) in zip(dt_weak, dt_strong):
        # SHOULD REPLACE THE TRANSFORMATION OF HORIZONTAL FLIP WITH SOMETHING PHOTOMETRIC
        weak_targets = generate_pseudo_labels(teacher.ema, img_weak, device)
        
        for target in weak_targets:
            target["boxes"] = target["boxes"].to(device)
            target["labels"] = target["labels"].to(device)
        img_strong = img_strong.to(device)
        loss_dict = student(img_strong, weak_targets)
        loss = sum(loss_dict.values())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        teacher.update(student)
        for k, v in loss_dict.items():
            history[k] += v.item()

        history["total"] += loss.item()
        train_batches += 1
    for key in history:
        history[key] = history[key] / train_batches
    return history


def run_semi_supervised_pipeline(checkpoint_path, epochs, dt_weak, dt_strong):
    student, _, _ = load_checkpoint(checkpoint_path=checkpoint_path, optimizer=None, device=device)
    teacher = build_ema(student)
    optimizer = torch.optim.SGD(student.parameters(), lr=1e-2, momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    history = {key : [] for key in METRIC_KEYS}
    
    for epoch in range(epochs):
        print(f"\n==================== Epoch {epoch+1}/{epochs} ====================\n")
        train_history = train_self_supervised_one_epoch(teacher, student, optimizer, dt_weak, dt_strong)
        lr_scheduler.step(train_history["total"])
        for key, val in train_history.items():
            history[key].append(val)

checkpoint_path="checkpoints/checkpoint_epoch_42.pth"
run_semi_supervised_pipeline(checkpoint_path, 10, dt_train_unlabeled_weakaug, dt_train_unlabeled_strongaug)


Model weights loaded from checkpoints/checkpoint_epoch_42.pth
Resuming from epoch 42




KeyError: 'loss_box_reg'