In [1]:
from tqdm import tqdm

import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection import MaskRCNN

from maskrcnn import *

In [2]:
train_root = "data/Shampoo_5class.v2-only_non_defective-21-06-2024.coco/train"
val_root = "data/Shampoo_5class.v2-only_non_defective-21-06-2024.coco/valid"
test_root = "data/Shampoo_5class.v2-only_non_defective-21-06-2024.coco/test"

transform = Compose([
    ToTensor(),
    RandomHorizontalFlip(0.5),
])

train_dataset = CustomDataset(root=train_root, transforms=transform, max_samples=1)
val_dataset = CustomDataset(root=val_root, transforms=transform, max_samples=1)
test_dataset = CustomDataset(root=test_root, transforms=transform, max_samples=1)

In [3]:
def collate_fn(batch):
    return tuple(zip(*batch))


train_loader = DataLoader(train_dataset, batch_size=2048, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False, collate_fn=collate_fn)

In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    # device = torch.device("mps")
    device = torch.device("cpu")
else:
    device = torch.device("cpu")

In [None]:
backbone = resnet_fpn_backbone('resnet50', pretrained=True)
model = MaskRCNN(backbone, num_classes=6)
model.to(device)

optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

In [None]:
EPOCHS = 1

model.train()
for epoch in tqdm(range(EPOCHS)):
    train_loss = 0.0
    
    for images, targets in tqdm(train_loader):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        loss_dict = model(images, targets)

        # Compute total loss
        losses = sum(loss for loss in loss_dict.values())

        # Print individual losses for debugging or monitoring
        for key, value in loss_dict.items():
            print(f"{key}: {value.item()}")
        print("------------------------------------------------->")

        losses.backward()
        optimizer.step()

        train_loss += losses.item()

    train_loss /= len(train_loader)

In [8]:
# TODO move to inside train loop, here for testing purposes
# TODO migrate to PyTorch Lightning?

model.eval()
val_loss = 0.0

with torch.no_grad():
    for images, targets in tqdm(val_loader):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        for item in loss_dict:
            print(item)
            
        # TODO loss must be calculated here
        losses = sum(loss for loss in loss_dict.values())

        val_loss += losses.item()

val_loss /= len(val_loader)

print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {train_loss}, Val Loss: {val_loss}")

  0%|          | 0/1 [00:01<?, ?it/s]

{'boxes': tensor([[4.8168e+02, 3.9948e+02, 6.5818e+02, 4.8509e+02],
        [4.6477e+02, 4.3070e+02, 6.4606e+02, 5.1593e+02],
        [2.0057e+03, 5.6887e+01, 2.0688e+03, 1.1683e+02],
        [4.8436e+02, 4.7858e+02, 6.6091e+02, 5.6511e+02],
        [5.1432e+02, 5.2811e+02, 5.7222e+02, 5.8795e+02],
        [2.0278e+03, 5.5685e+01, 2.0918e+03, 1.1625e+02],
        [7.7886e+01, 4.3289e+02, 2.4684e+02, 5.1135e+02],
        [1.8858e+03, 5.0925e+02, 2.0627e+03, 5.9290e+02],
        [1.5789e+03, 3.4311e+02, 1.8335e+03, 5.8596e+02],
        [1.6233e+03, 4.5762e+02, 1.6835e+03, 5.1936e+02],
        [1.6451e+03, 4.6380e+02, 1.7064e+03, 5.2271e+02],
        [6.2025e+01, 1.8235e+02, 3.4062e+02, 4.2989e+02],
        [8.0622e+02, 4.6136e+02, 9.8349e+02, 5.5147e+02],
        [1.6223e+03, 4.8060e+02, 1.6841e+03, 5.4114e+02],
        [4.5638e+02, 5.0625e+02, 6.2649e+02, 5.8809e+02],
        [1.9394e+03, 4.1227e+02, 2.1070e+03, 5.0053e+02],
        [1.7302e+02, 5.5923e+02, 3.4372e+02, 6.4009e+02],
    




AttributeError: 'list' object has no attribute 'values'