In [1]:
import torch
#import kagglehub
from torchvision.transforms import v2
import os #for loading the data
from torch import nn
from torch.nn import functional as F

In [None]:
from torchvision.models.resnet import ResNet,BasicBlock,Bottleneck,wide_resnet50_2

In [None]:
# Root directory that contains the `train/`, `validation/` and `test/` split folders.
# Defaults to '' (paths then resolve relative to the notebook's working directory);
# set the FIXMATCH_DATA_DIR environment variable to point elsewhere without editing the notebook.
path = os.environ.get('FIXMATCH_DATA_DIR', '')


In [4]:

from utils import create_weak_aug, create_strong_aug, create_valid_transform

# All pipelines must feed the network the same input resolution, so the size is
# defined once instead of being repeated for each transform.
image_size = (224, 224)
# Weak view: used to produce the pseudo-labels (see the threshold explanation below).
weak = create_weak_aug(size=image_size)
# Strong view: presumably the view trained against the pseudo-labels, as in FixMatch.
strong = create_strong_aug(size=image_size)
# Deterministic evaluation transform.
# NOTE(review): not used by any dataset or loader in this notebook.
valid_transform = create_valid_transform(size=image_size)

In [5]:
# Each data split lives in its own sub-directory under the data root `path`.
train_path, valid_path, test_path = (
    os.path.join(path, split_dir) for split_dir in ('train', 'validation', 'test')
)

We build the labelled and unlabelled datasets with the custom dataset classes defined in `datasets.py`, using the transforms created above.

In [6]:
from datasets import labelled_TensorDataset, unlabelled_TensorDataset

# The labelled pool is read from the `validation` folder and the unlabelled pool from `train`.
labelled_set = labelled_TensorDataset(transform=weak, name=valid_path)
# Unlabelled images get both views: `transform` (weak) and `target_transform` (strong).
# Presumably the dataset yields (weak(x), strong(x)) pairs; confirm in datasets.py.
unlabelled_set = unlabelled_TensorDataset(transform=weak, target_transform=strong, name=train_path)


Premature end of JPEG file


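The labelled set is then split according to the number of labelled images kept per class (`n_per_label`), and the remaining images are held out for validation. The default is 100 images per class; the experiments below also try 250 and 500.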

We define our training parameters, loosely following the original FixMatch paper (https://arxiv.org/abs/2001.07685).



In [7]:
# --- Loading / schedule ---
num_workers = 2
epochs = 100

# --- FixMatch hyper-parameters ---
# batch_size: labelled images per step. The unlabelled loader draws
# ratio * batch_size images per step (FixMatch's mu, which the paper sets to 7).
# ratio is kept at 4 because of the GPU's memory capacity.
batch_size, ratio = 16, 4
# Presumably the weight of the unlabelled (pseudo-label) loss term, lambda_u in FixMatch.
loss_weight = 1.0

# --- SGD optimizer (Nesterov momentum) ---
lr, momentum, weight_decay, nesterov = 1e-3, 0.5, 0.03, True

In [8]:
from torch.utils.data import DataLoader

# Unlabelled batches are `ratio` times larger than labelled ones.
# The workers are kept alive because this single loader is reused by every run of the training grid.
unlabel_loader = DataLoader(
    unlabelled_set,
    batch_size=int(ratio * batch_size),
    shuffle=True,
    num_workers=num_workers,
    persistent_workers=True,
    pin_memory=True,
)


As in the FixMatch paper, a confidence threshold decides whether each image's pseudo-label is used. We want to keep using PyTorch's `CrossEntropyLoss`. So whenever the model's confidence on the weakly augmented version of an input is below the threshold parameter, its pseudo-label is set to 3. Because 3 is passed as `ignore_index`, that sample contributes zero loss.

In [9]:
from torch.nn import CrossEntropyLoss

# The loss is kept per-sample (reduction='none'), presumably so the training loop can weight them.
# Targets must be class indexes in [0, C). Low-confidence pseudo-labels are set to 3,
# which lies outside that range for the 2-class model, so ignore_index=3 gives them zero loss.
criterion = CrossEntropyLoss(reduction='none', ignore_index=3)
criterion.ignore_index

3

In [10]:
# Train on the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [None]:
from utils import epoch_loop, validate
from itertools import product
from utils import labelset_split


# Experiment grid: one full FixMatch training run per
# (architecture, confidence threshold, labelled images per class) combination.
models = ['resnet']
thresh_vals = [0.5]
label_samples = [100, 250, 500]
seed = 0  # torch is re-seeded before every run so each grid point is reproducible

for model_name, threshold, n_labels in product(models, thresh_vals, label_samples):
    # Seed before splitting and weight init so runs differ only in their hyper-parameters.
    # NOTE(review): only torch's RNG is seeded; if labelset_split uses `random`/numpy, seed those too.
    torch.manual_seed(seed)

    save_path = f"{model_name}{threshold}_{n_labels}.pth.tar"
    print("We are training " + save_path)

    labelled_filtered_set, val_set = labelset_split(labelled_set, n_per_label=n_labels)
    label_loader = DataLoader(labelled_filtered_set, batch_size=batch_size, shuffle=True,
                              pin_memory=True, num_workers=num_workers, persistent_workers=True)
    # Evaluation metrics do not depend on sample order, so validation is not shuffled.
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                            pin_memory=True, num_workers=num_workers, persistent_workers=True)
    # NOTE(review): val_set comes from `labelled_set`, which was built with transform=weak,
    # and `valid_transform` is never used. Validation images are therefore presumably still
    # randomly augmented; confirm what labelset_split returns.

    # `model_name` (string) and `model` (network) are kept separate so the loop
    # variable is never shadowed. `model` still holds the last trained network afterwards.
    if model_name == 'resnet':
        # Bottleneck blocks with [3, 4, 6, 3] layers at the default width form ResNet-50.
        # torchvision's wide_resnet50_2 uses the same layout plus width_per_group=128.
        model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=2)
    else:
        raise ValueError(f"Unknown model name: {model_name!r}")
    model.to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                                weight_decay=weight_decay, nesterov=nesterov)

    # NOTE(review): the logged validation accuracy (~13-15%) is far below chance for a
    # 2-class problem, while training accuracy is ~79%. Check `validate`'s accuracy
    # computation and the validation transform.
    epoch_loop(model, label_loader, unlabel_loader, val_loader, optimizer, criterion, device,
               epochs, threshold, loss_weight, verbose=False, save_path=save_path)


We are training resnet0.5_100.pth.tar
Fixmatch with threshold:  0.5
[34m[Epoch: 1/100][0m Training
[34m[Epoch: 1/100][0m Avg loss: 1.0452 | Accuracy: 76.2347
[34m[Epoch: 1/100][0m Validation
[34m[Epoch: 1/100][0m Avg loss: 0.4155 | Accuracy: 12.7225

[32mBest model so far. Saving model as model.pth[0m

[34m[Epoch: 3/100][0m Training
[34m[Epoch: 3/100][0m Avg loss: 0.6052 | Accuracy: 78.1924
[34m[Epoch: 3/100][0m Validation
[34m[Epoch: 3/100][0m Avg loss: 0.2749 | Accuracy: 14.3429

[32mBest model so far. Saving model as model.pth[0m

[34m[Epoch: 5/100][0m Training
[34m[Epoch: 5/100][0m Avg loss: 0.3887 | Accuracy: 79.1184
[34m[Epoch: 5/100][0m Validation
[34m[Epoch: 5/100][0m Avg loss: 0.2682 | Accuracy: 14.5969

[32mBest model so far. Saving model as model.pth[0m

[34m[Epoch: 36/100][0m Training
[34m[Epoch: 36/100][0m Avg loss: 0.2104 | Accuracy: 79.3235
[34m[Epoch: 36/100][0m Validation
[34m[Epoch: 36/100][0m Avg loss: 0.2617 | Accuracy: 14.8403


: 