# Rexnet

# Imports

In [None]:
import torchvision
import torch
from torchvision import transforms
import timm
import torch.nn.functional as F
from torch import nn
from tqdm.notebook import tqdm
import wandb
import numpy as np 
import random 
import os 

# Utils

In [None]:
# Device string used by the rest of the notebook: the GPU when PyTorch can see
# one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
def move_to(obj, device):
    """Recursively move every tensor found in ``obj`` onto ``device``.

    Args:
        obj: A tensor, or an arbitrarily nested ``dict`` / ``list`` / ``tuple``
            whose leaves are tensors.
        device: Anything accepted by ``Tensor.to`` (e.g. ``'cuda'`` or ``'cpu'``).

    Returns:
        The moved tensor, or a new container mirroring the structure of
        ``obj`` with every tensor moved. Dict keys are passed through as-is.

    Raises:
        TypeError: If ``obj`` or any nested value is not a tensor, dict, list
            or tuple.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [move_to(v, device) for v in obj]
    if isinstance(obj, tuple):
        moved = [move_to(v, device) for v in obj]
        # Named tuples take their fields positionally; plain tuples take a
        # single iterable.
        return type(obj)(*moved) if hasattr(obj, "_fields") else tuple(moved)
    raise TypeError(f"Invalid type for move_to: {type(obj).__name__}")

In [None]:
def pretrain_loss(z1, z2, temperature=0.5):
    """Contrastive loss between two batches of paired embeddings.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form a positive pair; every
    other row of the concatenated ``[z1; z2]`` batch acts as a negative.
    For each row the loss is
    ``-log(exp(pos / T) / sum_{j != self} exp(sim_j / T))`` where similarities
    are dot products, averaged over all ``2 * batch`` rows.

    The computation is done in log-space (``logsumexp``) so that large
    similarities or a small temperature cannot overflow ``exp`` and produce
    ``inf`` / ``NaN``.

    Args:
        z1: Tensor of shape ``(batch, dim)``. The similarity is a plain dot
            product, so callers presumably pass L2-normalised embeddings.
        z2: Tensor with the same shape as ``z1``.
        temperature: Positive scaling factor applied to the similarities.

    Returns:
        Scalar tensor holding the mean loss.
    """
    batch_size = z1.shape[0]
    out = torch.cat((z1, z2), dim=0)

    # Pairwise similarity logits between all 2N embeddings.
    logits = torch.mm(out, out.t()) / temperature

    # A sample must not be contrasted with itself: send the diagonal to -inf
    # so it contributes exp(-inf) = 0 to the denominator.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=logits.device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # Logit of each row's positive partner (identical for both halves).
    pos_logits = torch.sum(z1 * z2, dim=-1) / temperature
    pos_logits = torch.cat([pos_logits, pos_logits], dim=0)

    # -log(exp(pos) / sum(exp(others))) == logsumexp(others) - pos
    loss = (torch.logsumexp(logits, dim=-1) - pos_logits).mean()
    return loss

In [None]:
def test_pretrain_loss():
    """Smoke-test ``pretrain_loss`` on two random 2x128 batches and print the loss."""
    first_view, second_view = torch.rand(2, 128), torch.rand(2, 128)
    print(pretrain_loss(first_view, second_view))

In [None]:
# Run the loss smoke test once; it prints the loss computed on random inputs.
test_pretrain_loss()

In [None]:
# Per-channel mean / std handed to ``Normalize`` by both pipelines
# (presumably the CIFAR-10 channel statistics).
_NORM_MEAN = [0.4914, 0.4822, 0.4465]
_NORM_STD = [0.2023, 0.1994, 0.2010]

# Random augmentation used to create the two views during pre-training.
# There is no ``ToTensor`` step: the pipeline is applied to tensors.
_colour_jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([_colour_jitter], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Deterministic pipeline used at evaluation time: normalisation only.
test_transform = transforms.Compose([
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

In [None]:
def seed_everything(seed_value):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed_value: Integer seed applied to ``random``, ``numpy.random``,
            ``torch`` (CPU and, when available, every CUDA device) and
            exported as ``PYTHONHASHSEED``.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    # Only affects hash randomisation of interpreters started after this
    # point (e.g. subprocesses); the current process has already fixed its hash seed.
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        # Benchmark mode lets cuDNN auto-tune and pick potentially different
        # algorithms from run to run, which defeats the purpose of seeding;
        # it has to be off for reproducible results.
        torch.backends.cudnn.benchmark = False

# Dataset

In [None]:
# Workaround for the dataset apparently failing to download: swap the default
# HTTPS context factory for the unverified one.
# NOTE(review): this switches off TLS certificate verification for every HTTPS
# request in this process that uses the default context, not only the dataset
# download. Prefer fixing the local certificate store if possible.
import ssl 
ssl._create_default_https_context = ssl._create_unverified_context

# CIFAR-10 with the default split, stored under ./data (downloaded if missing).
# Images are only converted to tensors here; no normalisation or augmentation.
ds_train = torchvision.datasets.CIFAR10('data', download=True, transform = transforms.ToTensor())

# CIFAR-10 loaded with train=False, same storage location and same transform.
ds_test = torchvision.datasets.CIFAR10('data', train=False, download=True, transform = transforms.ToTensor())


# Model

In [None]:
class SimModel(nn.Module):
    """timm ``rexnet_150`` backbone followed by a two-layer projection head.

    ``forward`` returns a triple ``(x, h, z)``:

    * ``x`` - raw backbone output (classifier replaced by ``nn.Identity``),
    * ``h`` - output of the first linear layer, L2-normalised along dim 1,
    * ``z`` - ``fc2(relu(fc1(x)))``, L2-normalised along dim 1.

    NOTE: earlier revisions discarded the ReLU result, making the head purely
    linear. Checkpoints trained that way still load (same parameters) but
    will produce different ``z`` with the corrected head.
    """

    def __init__(self):
        super().__init__()
        # Created without pretrained weights; pass pretrained=True to timm to
        # start from its published weights instead.
        self.base = timm.create_model('rexnet_150', num_classes=1)
        # Remove the classification layer so the backbone emits its features.
        self.base.head.fc = nn.Identity()
        # 1920 must match the width of the backbone output fed into fc1.
        self.fc1 = nn.Linear(1920, 512)
        self.fc2 = nn.Linear(512, 128)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(backbone features, normalised h, normalised z)`` for batch ``x``."""
        x = self.base(x)
        h = self.fc1(x)
        # The non-linearity must feed fc2; nn.ReLU() is not in-place, so the
        # pre-activation ``h`` returned below is left untouched.
        z = self.fc2(self.relu(h))

        h = F.normalize(h, dim=1)
        z = F.normalize(z, dim=1)
        return x, h, z

In [None]:
def eval_pretrain_fn(model, dl_test, dl_train, k=200):
    """Score the embeddings with a k-nearest-neighbour majority vote.

    Every batch of ``dl_train`` is embedded to build a reference bank. Each
    test sample is then assigned the most frequent label among the ``k`` bank
    entries with the largest dot product to it, separately for the second
    and third model outputs.

    Args:
        model: Callable returning a triple per batch; the second and third
            elements are used as embeddings (assumed to be ``(batch, dim)``).
        dl_test: Iterable of ``(inputs, labels)`` batches to classify.
        dl_train: Iterable of ``(inputs, labels)`` batches forming the bank.
        k: Number of neighbours that vote for each test sample.

    Returns:
        Tuple ``(accuracy_q_k, accuracy_z_k)``: fraction of correctly
        classified test samples for the second and third output respectively.
    """
    total = 0
    hits_q = 0
    hits_z = 0
    model.eval()
    with torch.no_grad():
        # Embed the whole training set once to serve as the reference bank.
        bank_q, bank_z, bank_labels = [], [], []
        for images, targets in tqdm(dl_train):
            images = test_transform(move_to(images, DEVICE))
            targets = move_to(targets, DEVICE)

            _, emb_q, emb_z = model(images)

            bank_z.append(emb_z)
            bank_q.append(emb_q)
            bank_labels.append(targets)

        bank_q = torch.cat(bank_q, dim=0)
        bank_z = torch.cat(bank_z, dim=0)
        bank_labels = torch.cat(bank_labels, dim=0)

        for images, targets in tqdm(dl_test):
            images = test_transform(move_to(images, DEVICE))
            targets = move_to(targets, DEVICE)

            _, emb_q, emb_z = model(images)

            total += len(targets)

            # Similarity of every bank entry (rows) to every test sample (columns).
            sims_q = bank_q @ emb_q.t()
            sims_z = bank_z @ emb_z.t()

            # Indices of the k most similar bank entries per test sample: (k, batch).
            nearest_q = sims_q.topk(k=k, dim=0).indices
            nearest_z = sims_z.topk(k=k, dim=0).indices

            # Labels of those neighbours, then the majority label per column.
            votes_q = bank_labels[nearest_q]
            votes_z = bank_labels[nearest_z]
            predicted_q = torch.mode(votes_q, dim=0).values
            predicted_z = torch.mode(votes_z, dim=0).values

            hits_q += torch.sum(predicted_q == targets)
            hits_z += torch.sum(predicted_z == targets)

    return hits_q / total, hits_z / total

# WandB

In [None]:
# Hyper-parameters shared by the data loaders, the optimizer and the seeding.
batch_size_train = 90
batch_size_test = 40
lr = 1e-3
SEED = 42

### Weights & Biases tracking is optional; this section can be commented out.

params = {
    'batch_size_train': batch_size_train,
    'batch_size_test': batch_size_test,
    'learning_rate': lr,
    'seed': SEED
}

wandb.init(project="DL", entity="pydqn", config=params, reinit=True)
# Keep the best value of each metric in the run summary.
wandb.define_metric('train_loss', summary='min')
wandb.define_metric('evaluation_accuracy', summary='max')
# The training loop logs its accuracies under 'test_acc_q' / 'test_acc_z',
# so those are the keys that need a 'max' summary to be recorded.
wandb.define_metric('test_acc_q', summary='max')
wandb.define_metric('test_acc_z', summary='max')
run_name = wandb.run.name
config = wandb.config

###

# Seed all RNGs before the model and the data loaders are created.
seed_everything(SEED)



# Main

In [None]:
model = SimModel()
# Place the model on the same device the batches are moved to, so the
# notebook also runs on machines without a GPU (model.cuda() would raise there).
model.to(DEVICE)
# model.load_state_dict(torch.load("best_model_rexnet.pt"))
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
criterion = pretrain_loss
# Pinned host memory only helps host-to-GPU copies; skip it on CPU.
_pin_memory = DEVICE == 'cuda'
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size_train, shuffle=True, num_workers=2, pin_memory=_pin_memory)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=batch_size_test, shuffle=True, num_workers=2, pin_memory=_pin_memory)

In [None]:
# Contrastive pre-training: 500 epochs, kNN evaluation after every epoch,
# checkpointing on the best 'z' accuracy and logging to Weights & Biases.
cumsum_q = 0
cumsum_z = 0
# Best 'z' kNN accuracy seen so far; -inf guarantees the first epoch is saved.
best_acc = float('-inf')
for i in tqdm(range(500)):
    total_loss = 0.0
    # eval_pretrain_fn switches the model to eval mode, so re-enable training
    # mode at the start of every epoch.
    model.train()
    for batch in tqdm(dl_train):
        X, _ = batch  # labels are not used during contrastive pre-training
        X = move_to(X, DEVICE)

        # Two independently augmented views of the same batch.
        # NOTE(review): torchvision transforms applied to a batched tensor draw
        # one set of random parameters for the whole batch, so every image in
        # a view gets the same crop/flip/jitter. Consider augmenting per
        # sample if that is not intended.
        X1 = train_transform(X)
        X2 = train_transform(X)

        _, _, out1 = model(X1)
        _, _, out2 = model(X2)

        loss = criterion(out1, out2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a plain float: summing the loss tensors themselves would
        # keep every batch's autograd history alive for the whole epoch.
        total_loss += loss.item()

    q_k_acc, z_k_acc = eval_pretrain_fn(model, dl_test, dl_train)
    # Work with Python floats from here on (printing, comparison, logging).
    q_k_acc = q_k_acc.item()
    z_k_acc = z_k_acc.item()
    print("total loss")
    print(total_loss)
    print(f"q k acc : {q_k_acc}")
    print(f"z k acc : {z_k_acc}")

    cumsum_q += q_k_acc
    cumsum_z += z_k_acc

    if z_k_acc > best_acc:
        best_acc = z_k_acc
        torch.save(model.state_dict(), "best_model_rexnet150.pt")

    metrics = {
        'epoch': i,
        # Mean loss per batch: total_loss is a sum over batches, so it is
        # divided by the number of batches rather than by the batch size.
        'train_loss': total_loss / max(len(dl_train), 1),
        'test_acc_q': q_k_acc,
        'test_acc_z': z_k_acc,
    }

    if i % 10 == 9:
        # Every 10th epoch also report the mean accuracy of the last 10 epochs.
        metrics['average_acc_h'] = cumsum_q / 10
        metrics['average_acc_z'] = cumsum_z / 10
        cumsum_q = 0
        cumsum_z = 0

    # One log call per epoch keeps all of an epoch's metrics on the same step.
    wandb.log(metrics)