In [None]:
# import statements for python, torch and companion libraries and your own modules
import json
import os
from glob import glob
from pathlib import Path

from PIL import Image

import torch
from torch import nn
import torchvision.transforms as transforms
from torchvision.models import swin_v2_s, Swin_V2_S_Weights

### torchinfo
from torchinfo import summary

from matplotlib.colors import to_rgb
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['lines.linewidth'] = 2.0

In [None]:
# Global hyper-parameters shared by the cells below.
# NOTE(review): keys such as embed_dim / patch_size / num_patches look like
# ViT-style settings -- confirm which of them the training cells actually consume.
model_kwargs = dict(
    embed_dim=256,
    hidden_dim=512,
    num_heads=8,
    num_layers=6,
    patch_size=4,
    num_channels=3,
    num_patches=64,
    num_classes=80,
    dropout=0.2,
)


In [None]:
# device initialization



In [None]:
# data directories initialization and dataset class definitions
class COCOTrainImageDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, annotations_dir, max_images=None, transform=None):
        self.img_labels = sorted(glob("*.cls", root_dir=annotations_dir))
        if max_images:
            self.img_labels = self.img_labels[:max_images]
        self.img_dir = img_dir
        self.annotations_dir = annotations_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, Path(self.img_labels[idx]).stem + ".jpg")
        labels_path = os.path.join(self.annotations_dir, self.img_labels[idx])
        image = Image.open(img_path).convert("RGB")
        with open(labels_path) as f: 
            labels = [int(label) for label in f.readlines()]
        if self.transform:
            image = self.transform(image)
        labels = torch.zeros(80).scatter_(0, torch.tensor(labels), value=1)
        return image, labels


class COCOTestImageDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_list = sorted(glob("*.jpg", root_dir=img_dir))    
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_list[idx])
        image = Image.open(img_path).convert("RGB")        
        if self.transform:
            image = self.transform(image)
        return image, Path(img_path).stem # filename w/o extension


In [None]:
# instantiation of transforms, datasets and data loaders
# TIP : use torch.utils.data.random_split to split the training set into train and validation subsets

# split ratio -- train images : validation images = 9 : 1


In [None]:
# Class-name lookup: position in this tuple is the integer class index (0-79).
classes = (
    # 0-9
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    # 10-19
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    # 20-29
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    # 30-39
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    # 40-49
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    # 50-59
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    # 60-69
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    # 70-79
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
)


In [None]:
# instantiation and preparation of network model

# Swin Transformer V2 (small) with pre-trained weights, used as a frozen backbone
weights = Swin_V2_S_Weights.DEFAULT
model = swin_v2_s(weights=weights)

# freeze every pre-trained parameter; only the replacement head below is trainable
for param in model.parameters():
    param.requires_grad = False

# torchvision's SwinTransformer exposes its classifier as `head` (there is no `fc`),
# so swap that layer for a fresh one producing one logit per entry of `classes`.
# Newly created nn.Linear parameters have requires_grad=True by default.
model.head = nn.Linear(model.head.in_features, len(classes))

# pre-processing pipeline: tensor conversion, then the preset bundled with the weights
transform = transforms.Compose([transforms.ToTensor(),
                                # other transforms,
                                weights.transforms()])

In [None]:
# instantiation of loss criterion
# instantiation of optimizer, registration of network parameters



In [None]:
# definition of current best model path
# initialization of model selection metric



In [None]:
# creation of tensorboard SummaryWriter (optional)


In [None]:
# epochs loop:
#   train
#   validate on train set
#   validate on validation set
#   update graphs (optional)
#   is new model better than current model ?
#       save it, update current best metric



In [None]:
# close tensorboard SummaryWriter if created (optional)

