In [1]:
from torchvision.datasets import CocoDetection
from torchvision.datasets import VOCDetection
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torch

class COCO(Dataset):
    """COCO 2017 detection split wrapped as a multi-label classification dataset.

    Each item is ``(image, target)`` where ``target`` is a float vector of
    length ``n_classes`` holding 1 at position ``category_id - 1`` for every
    annotation attached to the image, and 0 elsewhere.

    Args:
        image_set: Split name interpolated into the image folder and the
            annotation file name (e.g. ``'train'`` -> ``train2017``).
    """

    def __init__(self, image_set='train'):
        # Resize to 224x224, convert to tensor, then normalise per channel.
        preprocessing = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        image_dir = f"../../Datasets/COCO/{image_set}2017"
        annotation_file = f"../../Datasets/COCO/annotations/instances_{image_set}2017.json"

        self.dataset = CocoDetection(root=image_dir, annFile=annotation_file, transform=preprocessing)
        self.n_classes = 91

    def __len__(self):
        """Number of images in the wrapped detection dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, multi_hot_target)`` for the sample at ``idx``."""
        target = torch.zeros(self.n_classes)
        image, annotations = self.dataset[idx]

        # The code treats category ids as 1-based, hence the shift by one.
        positions = [ann['category_id'] - 1 for ann in annotations]
        if positions:
            target[torch.tensor(positions, dtype=torch.long)] = 1

        return image, target

class VOC(Dataset):
    """Pascal VOC 2012 detection split wrapped as a multi-label classification dataset.

    Each item is ``(image, target)`` where ``target`` is a float vector of
    length ``n_classes`` and ``target[i] == 1`` means an object of class
    ``label_set[i]`` is annotated in the image.

    Args:
        image_set: Split name forwarded to ``VOCDetection`` (e.g. ``'train'``,
            ``'val'``).
        label_policy: ``'first'`` keeps only the first annotated object as the
            label; any other value (default ``'all'``) marks every annotated
            object.
    """

    def __init__(self, image_set='train', label_policy='all'):
        # Resize to 224x224, convert to tensor, then normalise per channel.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Position in this list defines the index of the class in the target vector.
        self.label_set = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
        self.dataset = VOCDetection(root="./data", year="2012", image_set=image_set, download=True, transform=transform)
        self.n_classes = len(self.label_set)
        self.label_policy = label_policy

    def __len__(self):
        """Number of images in the wrapped detection dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, multi_hot_target)`` for the sample at ``idx``.

        Raises:
            ValueError: if an annotated object name is not in ``label_set``.
        """
        image, labels = self.dataset[idx]

        objects = labels['annotation']['object']
        # NOTE(review): some torchvision versions appear to collapse a
        # single-object annotation into a bare dict instead of a one-element
        # list; normalise so both shapes are handled the same way.
        if isinstance(objects, dict):
            objects = [objects]
        if self.label_policy == "first":
            # Slicing (rather than [0]) keeps an empty annotation from raising.
            objects = objects[:1]

        y = torch.zeros(self.n_classes)
        for obj in objects:
            # list.index is already 0-based: no "- 1" shift here, otherwise
            # 'aeroplane' would wrap around to the last slot and every other
            # class would be off by one relative to label_set.
            y[self.label_set.index(obj['name'])] = 1

        return image, y

In [2]:
# VOC train split, every annotated object contributes a label.
train = VOC(image_set='train', label_policy='all')
# Same train split, but only the first annotated object is used as the label.
train_first = VOC(image_set='train', label_policy='first')
# VOC validation split with all annotated objects as labels.
val = VOC(image_set='val', label_policy='all')

In [3]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import timm
import torch

In [4]:
from augmentations import Masking, GaussianNoise, Stretch
from data import get_datasets
from AdaMatch import AdaMatch
from Baseline import Baseline
from FixMatch import FixMatch
from model import Model
from metrics import AUC, Error_Rate

In [5]:
# Hyperparameters
# -- optimisation schedule
epochs = 1024
learning_rate = 1e-3
momentum = 0.9

# -- batching; `mu` is only referenced by the commented-out unlabeled loader
#    (batch_size * mu) further down.
batch_size = 64
mu = 7

# -- compute device: use the GPU when CUDA is available, otherwise the CPU
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [6]:
# Both training loaders reshuffle their samples; the validation loader keeps
# a fixed order.
train_loader = DataLoader(dataset=train, batch_size=batch_size, shuffle=True)
train_loader_first = DataLoader(dataset=train_first, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=val, batch_size=batch_size, shuffle=False)
# Disabled: no `unlabeled_dataset` is defined in this notebook.
#unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=batch_size*mu, shuffle=True)

In [7]:
# Weak augmentation pipeline handed to the trainer as `weak_transform`.
# NOTE(review): the saved cell contained only the truncated token `tor`, which
# raises NameError on a fresh kernel. A random horizontal flip is used here as
# a label-preserving stand-in; confirm the intended weak augmentation.
weak = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
])

In [8]:
# Model configuration passed to the project's `Model` wrapper.
config = {'backbone': 'tf_efficientnetv2_b0', 'n_channels': 3}
model = Model(train.n_classes, config=config)

# SGD with momentum, learning rate annealed along a cosine curve over the
# whole run (T_max = number of epochs) down to eta_min.
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-8)

# Targets are multi-hot vectors, so the loss is binary cross-entropy on logits.
criterion = F.binary_cross_entropy_with_logits
metrics = [AUC, Error_Rate]

trainer = Baseline(weak_transform=weak)
model, log_baseline = trainer.train(
    train_loader,
    test_loader,
    model,
    optimizer,
    criterion,
    metrics,
    scheduler=scheduler,
    epochs=epochs,
    verbose=2,
    val_freq=1,
)

Epoch 278/1024
[1m Training 	|	 loss=12.445[0m
[1m Validation 	|	 loss=25.712  -  AUC=0.675  -  Error_Rate=0.811[0m

[1m Best : AUC=0.705  -  Error_Rate=0.822 at epoch 143


Training:  87%|██████████████████████████████████████████▍      | 78/90 [00:15<00:02,  5.12it/s]


KeyboardInterrupt: 