# List of models

In [None]:
import timm 
# Print the names of all model architectures that timm reports via list_models().
print(timm.list_models())

# Import module

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision.transforms.functional import rotate
import wandb
from tqdm import tqdm
import os
import warnings
import torch
import torchvision.transforms as transforms
import torchvision
import torch.nn as nn
from PIL import Image
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, RandAugment
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomRotation, RandomResizedCrop, ToTensor, Normalize


# Ask CUDA to launch kernels synchronously so a device-side error is reported
# at the call that caused it; this is a debugging aid and slows execution.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Suppress every UserWarning raised from here on.
warnings.filterwarnings("ignore", category=UserWarning)

# Loading data and data augmentation

In [None]:
class DataModule:
    """Build the labeled (train/validation) and unlabeled data loaders.

    Args:
        labeled_train_dataset_path: Root directory handed to ``ImageFolder``
            for the labeled images.
        unlabeled_train_dataset_path: Directory handed to
            ``UnlabeledDataset`` for the unlabeled images.
        train_transform: Transform applied to the non-augmented labeled
            images (both the train and the validation split) and to the
            unlabeled images.
        batch_size: Batch size of the labeled train and validation loaders.
        num_workers: Number of worker processes used by every loader.
    """

    def __init__(
        self,
        labeled_train_dataset_path,
        unlabeled_train_dataset_path,
        train_transform,
        batch_size,
        num_workers,
    ):
        self.train_transform = train_transform
        self.labeled_train_dataset_path = labeled_train_dataset_path
        self.unlabeled_train_dataset_path = unlabeled_train_dataset_path
        self.batch_size = batch_size
        self.num_workers = num_workers

    @staticmethod
    def _split_train_val(dataset, j):
        """Split *dataset* 80/20 using a generator seeded with ``3300 + j``."""
        train_size = int(0.8 * len(dataset))
        return torch.utils.data.random_split(
            dataset,
            [train_size, len(dataset) - train_size],
            generator=torch.Generator().manual_seed(3300 + j),
        )

    def labeled_dataloader(self, j, num_augmented_copies=10):
        """Return ``(train_loader, val_loader)`` for the labeled data.

        The training set is the 80% split seen ``num_augmented_copies`` times
        through a RandAugment pipeline plus once through ``train_transform``.
        The validation set is the remaining 20% through ``train_transform``.

        Args:
            j: Offset added to the base seed 3300 that drives the split.
            num_augmented_copies: How many times the augmented view of the
                training split is repeated in the combined dataset.

        Returns:
            Tuple ``(traindataloader, valdataloader)``.
        """
        transform = Compose([
            RandAugment(),
            RandomResizedCrop(224, scale=(1.0, 1.0)),
            ToTensor(),
        ])

        labeled_dataset = ImageFolder(
            self.labeled_train_dataset_path, transform=self.train_transform
        )
        train_dataset, val_dataset = self._split_train_val(labeled_dataset, j)

        # One augmented view is enough: the transform runs every time a sample
        # is fetched, so repeating the same subset yields fresh augmentations
        # without re-scanning the image folder or re-drawing the split.
        # The split uses the same seed and dataset length as the one above.
        augmented_view = ImageFolder(
            self.labeled_train_dataset_path, transform=transform
        )
        augmented_train, _ = self._split_train_val(augmented_view, j)

        train_combined_dataset = torch.utils.data.dataset.ConcatDataset(
            [augmented_train] * num_augmented_copies + [train_dataset]
        )

        traindataloader = DataLoader(
            train_combined_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        valdataloader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

        return traindataloader, valdataloader

    def unlabeled_dataloader(self, batch_size=28):
        """Return a shuffled loader over the unlabeled images.

        Args:
            batch_size: Batch size of the unlabeled loader (default 28,
                independent of the labeled ``batch_size``).
        """
        unlabeled_dataset = UnlabeledDataset(
            self.unlabeled_train_dataset_path, transform=self.train_transform
        )
        return DataLoader(
            unlabeled_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )


class UnlabeledDataset(Dataset):
    """In-memory dataset of RGB images read from a single directory.

    Every item is returned as ``(image, -1)``; ``-1`` is the placeholder
    label for unlabeled samples.

    Args:
        dataset_path: Directory whose entries are all opened as images.
        transform: Optional callable applied to each image on access.
    """

    # Upper bound on the number of images held in memory.
    MAX_IMAGES = 60002

    def __init__(self, dataset_path, transform=None):
        self.dataset_path = dataset_path
        self.transform = transform
        self.data = self.load_images()

    def load_images(self):
        """Load up to ``MAX_IMAGES`` images from ``dataset_path`` as RGB.

        Returns:
            A list of PIL images. The list is empty when the directory is
            empty, so ``len()`` keeps working on small directories.
        """
        data = []
        for file_name in os.listdir(self.dataset_path):
            if len(data) >= self.MAX_IMAGES:
                break
            file_path = os.path.join(self.dataset_path, file_name)
            # convert() returns a new, fully loaded image, so the source file
            # can be closed right away instead of staying open.
            with Image.open(file_path) as image:
                data.append(image.convert("RGB"))
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, -1  # Use a pseudo-label of -1 for unlabeled samples



In [None]:
import random
import numpy as np

def blend_images(image1, image2, f):
    """Return the convex combination ``(1 - f) * image1 + f * image2``.

    Both inputs must expose the same ``shape``; otherwise an
    ``AssertionError`` is raised.
    """
    assert image1.shape == image2.shape, "Images must have the same shape"

    weight_first = 1 - f
    return weight_first * image1 + f * image2

def blend_labels(label1, label2, f, max_l):
    """Build the soft target vector for a blend of two class labels.

    Args:
        label1: Class index of the first image; receives weight ``1 - f``.
        label2: Class index of the second image; receives weight ``f``.
        f: Blend factor given to the second label.
        max_l: Length of the returned vector (number of classes).

    Returns:
        A NumPy array of length ``max_l`` holding the two weights.
    """
    label = np.zeros(max_l)
    # Accumulate instead of assigning: when both images share a class the two
    # weights must add up to 1 rather than the second overwriting the first.
    label[label1] += 1 - f
    label[label2] += f
    return label
    
    

def blend(image1, image2, label1, label2, alpha = 0.2,max_l = 48):
    """Blend two samples with one Beta(alpha, alpha) draw.

    The same random factor weights both the images and the labels.

    Returns:
        Tuple ``(blended_image, blended_label)``.
    """
    mix = random.betavariate(alpha, alpha)
    return (
        blend_images(image1, image2, mix),
        blend_labels(label1, label2, mix, max_l),
    )


In [None]:
import torch
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    """Dataset pairing pre-computed images with labels by position."""

    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        # The length follows the image sequence only.
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

In [None]:
def create_augmented_dataset(dataset,nb_imgs):
    """Create ``nb_imgs`` blended samples from random pairs of *dataset*.

    Each sample blends two distinct, randomly chosen items of *dataset*
    via ``blend``. The result is wrapped in an ``ImageDataset``.
    """
    mixed_images, mixed_labels = [], []
    for _ in range(nb_imgs):
        first_idx, second_idx = random.sample(range(len(dataset)), 2)
        image1, label1 = dataset[first_idx]
        image2, label2 = dataset[second_idx]
        mixed_image, mixed_label = blend(image1, image2, label1, label2)
        mixed_images.append(mixed_image)
        mixed_labels.append(mixed_label)
    return ImageDataset(mixed_images, mixed_labels)

In [None]:
# Shared data module: images are resized to 224x224 and converted to tensors.
datamodule = DataModule(
    labeled_train_dataset_path="compressed_dataset/train",
    unlabeled_train_dataset_path="compressed_dataset/unlabelled",
    train_transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(size=[224, 224]),
            torchvision.transforms.ToTensor(),
        ]
    ),
    batch_size=32,
    num_workers=48,
)

In [None]:
# Build the labeled train/validation loaders; the argument 0 is the seed
# offset passed to DataModule.labeled_dataloader.
train_loader, val_loader = datamodule.labeled_dataloader(0)

In [None]:
# Build 20000 blended samples from `train_set`.
# NOTE(review): `train_set` is not defined in any visible cell; confirm where
# it comes from before running this cell in a fresh kernel.
train_dataset = create_augmented_dataset(train_set,20000)

In [None]:
# Replace the labeled train loader with one over `train_dataset`.
# Shuffling is off, so batches follow the dataset's own sample order.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Replace the validation loader with one over `val`.
# NOTE(review): `val` is not defined in any visible cell; confirm its origin.
val_loader = DataLoader(
    val,
    batch_size=64,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Loader over the unlabeled images, used by the semi-supervised loops below.
unlabeled_train_loader = datamodule.unlabeled_dataloader()

In [None]:
def freematch_augmentation(tensor_batch):
    """Apply a fixed chain of image transforms to *tensor_batch*.

    The steps are, in order: contrast x1.2, brightness x0.8, an affine warp
    (10 degree rotation, 0.2/0.2 translation, 0.8 scale, no shear) and a
    horizontal flip. No step is randomised. Gaussian blur and mean/std
    normalisation are not applied.
    """
    F = transforms.functional

    out = F.adjust_contrast(tensor_batch, 1.2)
    out = F.adjust_brightness(out, 0.8)
    out = F.affine(out, angle=10, translate=(0.2, 0.2), scale=0.8, shear = 0.)
    return F.hflip(out)


# Models

In [None]:
class ResNetFinetune(nn.Module):
    """Pretrained ResNet-50 backbone followed by a linear classifier.

    Args:
        num_classes: Output size of the ``classifier`` layer.
        frozen: When true, ``requires_grad`` is disabled on every backbone
            parameter so only ``classifier`` is trained.
    """

    def __init__(self, num_classes, frozen=False):
        super().__init__()
        self.backbone = torchvision.models.resnet50(pretrained=True)
        # Drop the original head; the backbone then feeds its 2048 features
        # straight into `classifier`.
        self.backbone.fc = nn.Identity()
        if frozen:
            for param in self.backbone.parameters():
                param.requires_grad = False
        self.classifier = nn.Linear(2048, num_classes)

    def load_model_weights(self, model_path):
        """Load a backbone ``state_dict`` saved at *model_path*.

        The checkpoint is mapped to CPU on load so a file saved on a GPU
        machine can still be read on a CPU-only host. ``load_state_dict``
        then copies the values into the existing parameters.
        """
        state_dict = torch.load(model_path, map_location="cpu")
        self.backbone.load_state_dict(state_dict)

    def forward(self, x):
        """Return class logits for the image batch *x*."""
        x = self.backbone(x)
        x = self.classifier(x)
        return x



In [None]:
import torchvision
import torch.nn as nn
import timm

class VisionFinetune(nn.Module):
    """timm ``vit_base_patch16_224`` backbone followed by a linear classifier.

    Args:
        frozen: When true, ``requires_grad`` is disabled on every backbone
            parameter so only ``classifier`` is trained.
        num_classes: Output size of the ``classifier`` layer (default 48).
    """

    def __init__(self, frozen = False, num_classes = 48):
        super().__init__()
        self.backbone = timm.create_model('vit_base_patch16_224', pretrained=True)
        # NOTE(review): this mirrors the ResNet head replacement, but the
        # classifier below still expects 1000 inputs, so the backbone
        # presumably keeps its own 1000-way head. Confirm that assigning
        # `fc` has any effect on this architecture.
        self.backbone.fc = nn.Identity()
        self.classifier = nn.Linear(1000, num_classes)
        if frozen:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def load_model_weights(self, model_path):
        """Load a backbone ``state_dict`` saved at *model_path*.

        The checkpoint is mapped to CPU on load so a file saved on a GPU
        machine can still be read on a CPU-only host. ``load_state_dict``
        then copies the values into the existing parameters.
        """
        state_dict = torch.load(model_path, map_location="cpu")
        self.backbone.load_state_dict(state_dict)

    def forward(self, x):
        """Return class logits for the image batch *x*."""
        x = self.backbone(x)
        x = self.classifier(x)
        return x

# Set up

In [None]:
# Experiment tracking run.
logger = wandb.init(project="challenge", name="run_vision")

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionFinetune().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# Mean cross-entropy with a label-smoothing factor of 0.1.
loss_fn = lambda input, target: torch.nn.functional.cross_entropy(input, target, reduction='mean', label_smoothing=0.1)

# Fine-tuning

In [None]:
# Supervised fine-tuning. range(7, 8) yields only j = 7, so exactly one model
# is trained and it is saved as "model7.pt".
for j in range(7, 8):

    # j // 2 is the seed offset for the train/validation split (3 for j = 7).
    train_loader, val_loader = datamodule.labeled_dataloader(j//2)
    # A fresh model and optimizer replace the ones created in the set-up cell.
    model = VisionFinetune().to(device)
    optimizer = torch.optim.ASGD(model.parameters(), lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0., foreach=None)

    for epoch in tqdm(range(1)):
        model.train()
        # Running totals for this epoch's training metrics.
        epoch_loss = 0
        epoch_num_correct = 0
        num_samples = 0
        accumulated_loss = 0  # NOTE(review): never read inside this loop
    
            # Labeled data: one optimizer step per batch
        for i, batch in enumerate(train_loader):
            images, labels = batch
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            preds = model(images)
            loss = loss_fn(preds, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Weight the batch-mean loss by the batch size so the epoch
            # average below is a per-sample mean.
            epoch_loss += loss.detach().cpu().numpy() * len(images)
            epoch_num_correct += (
                (preds.argmax(1) == labels).sum().detach().cpu().numpy()
                )
            num_samples += len(images)
            
        epoch_loss /= num_samples
        epoch_acc = epoch_num_correct / num_samples
        logger.log(
                {
                    "epoch": epoch,
                    "train_loss_epoch": epoch_loss,
                    "train_acc": epoch_acc,
                }
            )
    
            # Validation
        if epoch % 1 == 0:  # Always true, so the model is evaluated after every epoch
            model.eval()
            # The training totals are reused for the validation metrics.
            epoch_loss = 0
            epoch_num_correct = 0
            num_samples = 0
    
            with torch.no_grad():
                for batch_idx, batch in enumerate(val_loader):
                    images, labels = batch
                    images = images.to(device)
                    labels = labels.to(device)
                    preds = model(images)
                    loss = loss_fn(preds, labels)
                    epoch_loss += loss.item() * len(images)
                    epoch_num_correct += (preds.argmax(1) == labels).sum().item()
                    num_samples += len(images)
    
            epoch_loss /= num_samples
            epoch_acc = epoch_num_correct / num_samples
            logger.log(
                {
                    "epoch": epoch,
                    "val_loss_epoch": epoch_loss,
                    "val_acc": epoch_acc,
                }
            )
        # Checkpoint at the end of every epoch; the file name depends only on j.
        torch.save(model.state_dict(), "model"+str(j)+".pt")

In [None]:
# Save the current model weights as an extra checkpoint named "model-1.pt".
torch.save(model.state_dict(), "model-1.pt")

# Free Match

In [None]:
import torch
from tqdm import tqdm

# Semi-supervised training: supervised loss on labeled batches plus a
# pseudo-label loss on confidently predicted unlabeled images.
decay = 0.999  # NOTE(review): not read anywhere in this cell
t = 0.8  # minimum softmax confidence for an unlabeled image to get a pseudo-label


grad_accumulation_steps = 4  # Accumulate gradients over 4 batches

for epoch in tqdm(range(3)):
    model.train()
    epoch_loss = 0.0
    epoch_num_correct = 0
    num_samples = 0
    # Batches back-propagated since the last optimizer step. Reset every epoch
    # so nothing leaks from one epoch into the next.
    pending_batches = 0
    optimizer.zero_grad()

    # zip() stops at the shorter loader, so an epoch may cover only part of
    # the longer one.
    for data_l, data_ul in zip(train_loader, unlabeled_train_loader):
        images_l, labels = data_l
        images_l = images_l.to(device)
        labels = labels.to(device)

        # Supervised loss on the labeled batch.
        preds_l = model(images_l)
        Ls = loss_fn(preds_l, labels)
        loss = Ls

        # Predict on the unlabeled batch without tracking gradients and keep
        # only the samples whose top softmax probability exceeds t.
        with torch.no_grad():
            images_ul = data_ul[0].to(device)
            preds_ul = model(images_ul)
            condition = torch.nn.functional.softmax(preds_ul, dim=1).max(dim=1)[0] > t

        if condition.any():
            indices = preds_ul[condition].argmax(dim=1)
            # RandAugment is applied to uint8 images: scale the [0, 1] floats
            # to [0, 255], augment, then scale back.
            img = torch.clip(images_ul[condition] * 255.0, 0.0, 255.0)
            img = img.type(torch.uint8)
            img = RandAugment()(img)
            img = img.type(torch.float32) / 255.0
            preds_aug_ul = model(img)

            Lu = loss_fn(preds_aug_ul, indices)
            loss = Lu + Ls

        # Gradient accumulation: back-propagate each batch right away so its
        # autograd graph can be freed; gradients add up in the parameters.
        (loss / grad_accumulation_steps).backward()
        pending_batches += 1

        epoch_loss += loss.item() * len(images_l)
        num_samples += len(images_l)
        epoch_num_correct += (preds_l.argmax(1) == labels).sum().item()

        if pending_batches == grad_accumulation_steps:
            optimizer.step()
            optimizer.zero_grad()
            pending_batches = 0

    # Apply the gradients of a trailing, incomplete accumulation group.
    if pending_batches:
        optimizer.step()
        optimizer.zero_grad()

    # Average over the samples actually seen (zip may truncate the epoch).
    epoch_loss /= max(num_samples, 1)
    epoch_acc = epoch_num_correct / max(num_samples, 1)
    logger.log({
        "train_loss_epoch": epoch_loss,
        "train_acc": epoch_acc,
    })

    # Validation after every epoch
    model.eval()
    epoch_loss = 0
    epoch_num_correct = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            preds = model(images)
            loss = loss_fn(preds, labels)
            epoch_loss += loss.item() * len(images)
            epoch_num_correct += (preds.argmax(1) == labels).sum().item()

    epoch_loss /= len(val_loader.dataset)
    epoch_acc = epoch_num_correct / len(val_loader.dataset)
    logger.log({
        "val_loss_epoch": epoch_loss,
        "val_acc": epoch_acc,
    })
    # Overwrite the checkpoint at the end of every epoch.
    torch.save(model.state_dict(), "model.pt")
        


# Fix Match

In [None]:
import torch
from tqdm import tqdm

# Semi-supervised training: supervised loss on labeled batches plus a
# pseudo-label loss on confidently predicted unlabeled images that are then
# augmented with RandAugment.
decay = 0.999  # NOTE(review): not read anywhere in this cell
t = 0.8  # minimum softmax confidence for an unlabeled image to get a pseudo-label


grad_accumulation_steps = 4  # Accumulate gradients over 4 batches

for epoch in tqdm(range(3)):
    model.train()
    epoch_loss = 0.0
    epoch_num_correct = 0
    num_samples = 0
    # Batches back-propagated since the last optimizer step. Reset every epoch
    # so nothing leaks from one epoch into the next.
    pending_batches = 0
    optimizer.zero_grad()

    # zip() stops at the shorter loader, so an epoch may cover only part of
    # the longer one.
    for data_l, data_ul in zip(train_loader, unlabeled_train_loader):
        images_l, labels = data_l
        images_l = images_l.to(device)
        labels = labels.to(device)

        # Supervised loss on the labeled batch.
        preds_l = model(images_l)
        Ls = loss_fn(preds_l, labels)
        loss = Ls

        # Predict on the unlabeled batch without tracking gradients and keep
        # only the samples whose top softmax probability exceeds t.
        with torch.no_grad():
            images_ul = data_ul[0].to(device)
            preds_ul = model(images_ul)
            condition = torch.nn.functional.softmax(preds_ul, dim=1).max(dim=1)[0] > t

        if condition.any():
            indices = preds_ul[condition].argmax(dim=1)
            # RandAugment is applied to uint8 images: scale the [0, 1] floats
            # to [0, 255], augment, then scale back.
            img = torch.clip(images_ul[condition] * 255.0, 0.0, 255.0)
            img = img.type(torch.uint8)
            img = RandAugment()(img)
            img = img.type(torch.float32) / 255.0
            preds_aug_ul = model(img)

            Lu = loss_fn(preds_aug_ul, indices)
            loss = Lu + Ls

        # Gradient accumulation: back-propagate each batch right away so its
        # autograd graph can be freed; gradients add up in the parameters.
        (loss / grad_accumulation_steps).backward()
        pending_batches += 1

        epoch_loss += loss.item() * len(images_l)
        num_samples += len(images_l)
        epoch_num_correct += (preds_l.argmax(1) == labels).sum().item()

        if pending_batches == grad_accumulation_steps:
            optimizer.step()
            optimizer.zero_grad()
            pending_batches = 0

    # Apply the gradients of a trailing, incomplete accumulation group.
    if pending_batches:
        optimizer.step()
        optimizer.zero_grad()

    # Average over the samples actually seen (zip may truncate the epoch).
    epoch_loss /= max(num_samples, 1)
    epoch_acc = epoch_num_correct / max(num_samples, 1)
    logger.log({
        "train_loss_epoch": epoch_loss,
        "train_acc": epoch_acc,
    })

    # Validation after every epoch
    model.eval()
    epoch_loss = 0
    epoch_num_correct = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            preds = model(images)
            loss = loss_fn(preds, labels)
            epoch_loss += loss.item() * len(images)
            epoch_num_correct += (preds.argmax(1) == labels).sum().item()

    epoch_loss /= len(val_loader.dataset)
    epoch_acc = epoch_num_correct / len(val_loader.dataset)
    logger.log({
        "val_loss_epoch": epoch_loss,
        "val_acc": epoch_acc,
    })
    # Overwrite the checkpoint at the end of every epoch.
    torch.save(model.state_dict(), "model.pt")
        
