#SimCLR_MoCo_SSL

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.data import DataLoader, random_split, Subset
from torchsummary import summary
import numpy as np
import random

In [None]:
# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Using device: {device}')

Using device: cpu


In [None]:
# Hyperparameters
# -- contrastive pre-training --
batch_size = 256     # samples per batch
num_epochs = 100     # passes over the training set
temperature = 0.5    # temperature used by the NT-Xent loss
# -- fine-tuning --
batch_size_fine_tune = 32    # samples per batch
num_epochs_fine_tune = 10    # passes over the training set
# -- optimizer --
learning_rate = 1e-3

In [None]:
# Data transformations (augmentation) applied to every training image.
train_transform = transforms.Compose([
    # Random crop, resized to 224x224.
    transforms.RandomResizedCrop(224),
    # Random horizontal mirroring.
    transforms.RandomHorizontalFlip(),
    # Colour jitter, applied with probability 0.8.
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
        p=0.8,
    ),
    # Convert the image to grayscale with probability 0.2.
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    # Per-channel normalisation.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

در اینجا از ۴ نوع اگمنتیشن (داده‌افزایی) استفاده شده است؛ در صورت نیاز می‌توانیم اگمنتیشن‌های بیشتری هم اضافه کنیم.

In [None]:
# Load the CIFAR-10 training split (downloaded to ./data on first use) and
# wrap it in a shuffling loader for contrastive pre-training.
dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=train_transform,
    download=True,
)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:01<00:00, 98.8MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [None]:
# Model definition for SimCLR
class SimCLR(nn.Module):
    """SimCLR model: a backbone feature extractor followed by a projection head.

    Args:
        base_model: Backbone network (e.g. a ResNet). All of its child modules
            except the last one are reused as the feature extractor, and
            ``base_model.fc.in_features`` gives the input width of the
            projection head, so the backbone must expose an ``fc`` layer.
        out_dim: Size of the projected embedding ``z``.
    """

    def __init__(self, base_model, out_dim):
        super(SimCLR, self).__init__()
        # Kept as an attribute so existing checkpoints / callers still see it.
        self.base_model = base_model
        # Feature extractor f(.): the backbone without its final child module.
        self.features = nn.Sequential(*list(base_model.children())[:-1])
        # Projection head g(.): two linear layers with a ReLU in between.
        self.projection_head = nn.Sequential(
            nn.Linear(base_model.fc.in_features, 512),
            nn.ReLU(),
            nn.Linear(512, out_dim)
        )

    def forward(self, x):
        """Return ``(h, z)``: backbone features and their projection.

        Args:
            x: Batch of inputs accepted by the backbone.

        Returns:
            Tuple ``(h, z)`` where ``h`` has shape ``(batch, features)`` and
            ``z`` has shape ``(batch, out_dim)``.
        """
        # Flatten everything after the batch axis. Unlike squeeze(), this keeps
        # the batch axis even when the batch contains a single sample.
        h = torch.flatten(self.features(x), 1)
        z = self.projection_head(h)
        return h, z

In [None]:
# NT-Xent loss function (temperature-scaled cross entropy over cosine similarity)
def nt_xent_loss(z_i, z_j, temperature):
    """Compute the NT-Xent contrastive loss for two batches of embeddings.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair; every
    other row of the concatenated ``2N`` embeddings acts as a negative.

    Args:
        z_i: Tensor of shape ``(N, D)`` with the first set of embeddings.
        z_j: Tensor of shape ``(N, D)`` with the second set of embeddings.
        temperature: Positive scalar dividing the similarities.

    Returns:
        Scalar tensor: mean cross-entropy over all ``2N`` anchors.
    """
    N = z_i.shape[0]
    # L2-normalise so the dot products below are cosine similarities.
    z = nn.functional.normalize(torch.cat((z_i, z_j), dim=0), dim=1)
    sim = torch.mm(z, z.t()) / temperature

    # An embedding must not count as its own candidate: remove the diagonal
    # from the softmax by setting it to -inf.
    self_mask = torch.eye(2 * N, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float('-inf'))

    # The positive for row k (k < N) is row k + N, and vice versa.
    idx = torch.arange(N, device=z.device)
    labels = torch.cat((idx + N, idx), dim=0)
    return nn.functional.cross_entropy(sim, labels)

In [None]:

# Training function for SimCLR
def train_simclr(model, train_loader, optimizer, num_epochs):
    """Train ``model`` with the NT-Xent contrastive objective.

    Relies on the module-level ``device``, ``temperature`` and
    ``nt_xent_loss`` names.

    Args:
        model: Module whose forward returns ``(h, z)``; only ``z`` is used.
        train_loader: Iterable of ``(inputs, labels)`` batches. Labels are
            ignored (self-supervised training). ``inputs`` is either a single
            tensor or a ``(view_a, view_b)`` pair of tensors holding two
            augmented views of the same images.
        optimizer: Optimizer updating ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
    """
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, _) in enumerate(train_loader):
            optimizer.zero_grad()
            if isinstance(inputs, (list, tuple)):
                view_a, view_b = inputs
            else:
                # NOTE(review): with a single tensor both halves are the same
                # augmented image, so each positive pair is two identical
                # inputs. Feed a loader that yields two independently
                # augmented views to get distinct positives.
                view_a = view_b = inputs

            # Size of THIS batch; the final batch of an epoch can be smaller
            # than the configured batch size.
            n = view_a.shape[0]
            stacked = torch.cat([view_a, view_b], dim=0).to(device)
            _, z = model(stacked)
            # Rows [0, n) belong to the first view, rows [n, 2n) to the second.
            z_i, z_j = torch.split(z, n, dim=0)
            loss = nt_xent_loss(z_i, z_j, temperature)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Report the mean loss of the last 10 steps.
            if (i + 1) % 10 == 0:
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {running_loss / 10:.4f}')
                running_loss = 0.0

In [None]:
# MoCo Model definition
class MoCo(nn.Module):
    """Momentum Contrast model: query encoder, momentum key encoder, key queue.

    Args:
        base_encoder: Callable that builds an encoder; it is invoked as
            ``base_encoder(num_classes=out_dim)``, once for each encoder.
        out_dim: Dimensionality of the embeddings produced by the encoders.
        K: Number of keys held in the queue of negatives.
        m: Momentum coefficient for the key-encoder update.
        T: Temperature dividing the logits.
    """

    def __init__(self, base_encoder, out_dim, K=4096, m=0.99, T=0.07):
        super(MoCo, self).__init__()
        self.K = K
        self.m = m
        self.T = T

        # Create the encoders
        self.encoder_q = base_encoder(num_classes=out_dim)
        self.encoder_k = base_encoder(num_classes=out_dim)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # Registered as buffers (not plain attributes) so that model.to(device)
        # moves them with the encoders and state_dict() saves them.
        self.register_buffer(
            "queue", nn.functional.normalize(torch.randn(out_dim, K), dim=0)
        )
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def momentum_update_key_encoder(self):
        """Move each key-encoder parameter towards its query-encoder counterpart."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def dequeue_and_enqueue(self, keys):
        """Overwrite the oldest queue columns with ``keys`` (shape ``(n, out_dim)``).

        Any batch size is accepted: writes wrap around the end of the queue,
        and if more than ``K`` keys arrive only the last ``K`` are kept.
        """
        if keys.shape[0] > self.K:
            keys = keys[-self.K:]
        n = keys.shape[0]
        ptr = int(self.queue_ptr)
        # Column indices to overwrite, wrapping past the end of the queue.
        cols = (ptr + torch.arange(n, device=self.queue.device)) % self.K
        self.queue[:, cols] = keys.T
        self.queue_ptr[0] = (ptr + n) % self.K

    def forward(self, im_q, im_k):
        """Return ``(logits, labels)`` for a batch of query/key inputs.

        Args:
            im_q: Batch fed to the query encoder.
            im_k: Batch fed to the key encoder (no gradient flows through it).

        Returns:
            ``logits`` of shape ``(N, 1 + K)`` with the positive similarity in
            column 0, and ``labels``, an all-zero long tensor of length ``N``.
        """
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        with torch.no_grad():  # no gradient to keys
            self.momentum_update_key_encoder()  # update the key encoder
            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

        # Compute logits: one positive per query, K negatives from the queue.
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Snapshot the queue so the in-place enqueue below cannot affect autograd.
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= self.T

        # The positive always sits in column 0.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=q.device)
        self.dequeue_and_enqueue(k)

        return logits, labels


# Training function for MoCo
def train_moco(model, train_loader, optimizer, num_epochs):
    """Train a MoCo-style ``model`` with cross-entropy over its logits.

    Relies on the module-level ``device`` name.

    Args:
        model: Module called as ``model(im_q, im_k)`` and returning
            ``(logits, labels)``.
        train_loader: Iterable of ``(inputs, labels)`` batches. Labels are
            ignored. ``inputs`` is either a single tensor or an
            ``(im_q, im_k)`` pair of tensors holding two augmented views of
            the same images.
        optimizer: Optimizer updating ``model``'s trainable parameters.
        num_epochs: Number of passes over ``train_loader``.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, _) in enumerate(train_loader):
            optimizer.zero_grad()
            if isinstance(inputs, (list, tuple)):
                im_q, im_k = inputs
            else:
                # NOTE(review): with a single tensor the query and key are the
                # same augmented image. Feed a loader that yields two
                # independently augmented views to get distinct positives.
                im_q = im_k = inputs
            # Use the batch as-is rather than splitting by the configured batch
            # size, so a smaller final batch is handled as well.
            im_q = im_q.to(device)
            im_k = im_k.to(device)
            logits, labels = model(im_q, im_k)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Report the mean loss of the last 10 steps.
            if (i + 1) % 10 == 0:
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {running_loss / 10:.4f}')
                running_loss = 0.0