<a href="https://colab.research.google.com/github/FrancescaMusella/MLDL-Project/blob/main/centralized_finalversion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import copy

In [None]:
# Fix PyTorch's random seeds (CPU generator via torch.manual_seed, CUDA via
# torch.cuda.manual_seed) so runs of this notebook are repeatable.
# NOTE(review): recent PyTorch documents torch.manual_seed as seeding all
# devices, so the second call may be redundant — kept for explicitness.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

In [None]:
# Preprocessing applied to every CIFAR-100 image: fixed 32x32 size, conversion
# to a tensor, then per-channel normalisation with the mean/std below
# (presumably the ImageNet statistics the pretrained DINO backbone expects).
transform = T.Compose(
    [
        T.Resize(size=(32, 32)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
from torchvision.datasets import CIFAR100
# Load (downloading if needed) the CIFAR-100 train and test splits with the
# shared `transform`; train_val is split into train/validation in the next cell.
# NOTE(review): root='.data/' is a hidden directory — './data/' may have been intended.
train_val=CIFAR100(root='.data/', train=True, download=True, transform=transform)
test=CIFAR100(root='.data/', train=False, download=True, transform=transform)

In [None]:
# Train / validation / test split and data loaders.
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split

targets = train_val.targets

# Stratified 80/20 split of the CIFAR-100 training set into train and
# validation indices, so each class keeps the same share in both parts.
train_indices, val_indices = train_test_split(
    range(len(targets)),
    test_size=0.2,
    stratify=targets,
    random_state=42,
)

# NOTE: the name `train` is rebound to the training function in a later cell;
# train_loader keeps its own reference to this Subset.
train = torch.utils.data.Subset(train_val, train_indices)
val = torch.utils.data.Subset(train_val, val_indices)

train_loader = torch.utils.data.DataLoader(train, batch_size=32, shuffle=True)
# Evaluation loaders iterate in a fixed order: shuffling cannot improve an
# evaluation, but it changes batch composition and therefore the reported
# mean-of-batch-means loss from run to run.
val_loader = torch.utils.data.DataLoader(val, batch_size=32, shuffle=False)
test_loader = torch.utils.data.DataLoader(test, batch_size=32, shuffle=False)

In [None]:
#ViT-S/16
# Clone the DINO repository and list the working directory.
# NOTE(review): the backbone is loaded via torch.hub.load in the next cell, which
# presumably fetches the repo itself; no visible cell reads this local clone —
# confirm it is still needed.
!git clone https://github.com/facebookresearch/dino.git
!ls

In [None]:
# Pick the GPU when available; later cells move the new head to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
# Pretrained DINO ViT-S/16 backbone from torch.hub.
vits16_original = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=True).to(device)
# Deep copy so modifications below never touch the pretrained model.
vits16_new=copy.deepcopy(vits16_original)
print(vits16_new)

In [None]:
# Replace the classification head and choose which parameters stay trainable.
# The new head maps the 384-d features (presumably the ViT-S embedding width)
# to the 100 CIFAR-100 classes.
vits16_new.head = nn.Linear(384, 100, bias=True).to(device)

# A parameter is trainable exactly when its name contains one of these
# substrings; every other parameter is frozen.
for name, param in vits16_new.named_parameters():
    param.requires_grad = any(
        fragment in name
        for fragment in ("head", "patch_embed", "proj", "pos_drop", "attn")
    )

# Independent copies of the configured model for other experiments.
# NOTE(review): vits16_self_sgd is not referenced in the visible cells, and
# vits16_sparse is re-created before head pre-training.
vits16_self_sgd = copy.deepcopy(vits16_new)
vits16_sparse = copy.deepcopy(vits16_new)

print(vits16_new)

In [None]:
# Cross-entropy over the 100 class logits; shared by every training/evaluation cell.
loss_fn = nn.CrossEntropyLoss()

In [None]:
def train(epoch, model, train_loader, criterion, optimizer):
    """Run one training epoch and print the epoch's mean loss and accuracy.

    Logits are computed by taking the first token (index 0 along dim 1,
    presumably the CLS token) of ``model.get_intermediate_layers(inputs, n=1)``
    and passing it through ``model.head``.

    Args:
        epoch: Epoch number; used only in the printed log line.
        model: Module exposing ``get_intermediate_layers`` and ``head``.
        train_loader: Iterable of ``(inputs, targets)`` batches supporting ``len``.
        criterion: Loss function called as ``criterion(logits, targets)``.
        optimizer: Optimizer that steps the parameters being trained.

    Returns:
        ``(train_loss, train_accuracy)``: mean of the per-batch losses and the
        accuracy in percent. (Earlier callers that ignore the result still work.)
    """
    model.train()
    # Follow the model's device instead of hard-coding .cuda(), so the loop
    # also works when the notebook falls back to CPU.
    device = next(model.parameters()).device
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        intermediate_output = model.get_intermediate_layers(inputs, n=1)
        features = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
        outputs = model.head(features)

        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    train_loss = running_loss / len(train_loader)
    train_accuracy = 100. * correct / total
    print(f'Train Epoch: {epoch} Loss: {train_loss:.6f} Acc: {train_accuracy:.2f}%')
    return train_loss, train_accuracy

In [None]:
def validate(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradients and print the metrics.

    Logits are computed as in training: the first token (index 0 along dim 1)
    of ``model.get_intermediate_layers(inputs, n=1)`` is fed to ``model.head``.

    Args:
        model: Module exposing ``get_intermediate_layers`` and ``head``.
        val_loader: Iterable of ``(inputs, targets)`` batches supporting ``len``
            (any split may be passed; the notebook also passes the test loader).
        criterion: Loss function called as ``criterion(logits, targets)``.

    Returns:
        ``(val_accuracy, val_loss)``: accuracy in percent and the mean of the
        per-batch losses.
    """
    model.eval()
    # Use the model's device rather than hard-coding .cuda().
    device = next(model.parameters()).device
    val_loss = 0.0
    correct, total = 0, 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            intermediate_output = model.get_intermediate_layers(inputs, n=1)
            features = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
            outputs = model.head(features)
            loss = criterion(outputs, targets)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    val_loss = val_loss / len(val_loader)
    val_accuracy = 100. * correct / total

    print(f'Validation Loss: {val_loss:.6f} Acc: {val_accuracy:.2f}%')
    return val_accuracy, val_loss

In [None]:
#Vanilla SGDM grid search
# Grid search over learning rate x weight decay with torch.optim.SGD (momentum
# 0.9). Every configuration restarts from a fresh copy of vits16_new and is
# scored on the validation split after each epoch.
best_acc = 0
# [learning rate, weight decay, epoch] of the best validation accuracy so far.
best_param=[0, 0, 0]

learning_rates = [1e-3, 1e-4, 5e-4]
weight_decay_values=[1e-3, 1e-4, 5e-4]
num_epochs = 10

# Index into learning_rates.
lr_counter=0
# Index into weight_decay_values (despite the "mt" name it tracks weight decay).
mt_counter=0
# Per-epoch validation curves, indexed [lr index, weight-decay index, epoch - 1].
x_acc=torch.zeros(len(learning_rates),len(weight_decay_values),num_epochs)
x_loss=torch.zeros(len(learning_rates),len(weight_decay_values),num_epochs)
for lr in learning_rates:
    for weight_decay in weight_decay_values:
        model=copy.deepcopy(vits16_new)
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
        # NOTE(review): T_max=200 is far longer than the num_epochs=10 run here, so
        # the cosine schedule lowers the LR only slightly during the search — confirm intended.
        scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
        for epoch in range(1, num_epochs + 1):
            train(epoch, model, train_loader, loss_fn, optimizer)
            scheduler.step()
            val_accuracy, val_loss = validate(model, val_loader, loss_fn)

            # Keep the best (lr, weight decay, epoch) by validation accuracy.
            if val_accuracy > best_acc:
                best_acc = val_accuracy
                best_param[0]=lr
                best_param[1]=weight_decay
                best_param[2]=epoch


            x_acc[lr_counter,mt_counter,epoch-1]=val_accuracy
            x_loss[lr_counter,mt_counter,epoch-1]=val_loss
        mt_counter+=1
    mt_counter=0
    lr_counter+=1

print(f'Best validation accuracy: {best_acc:.2f}%')
print(f'Best learning rate: {best_param[0]}')
print(f'Weight decay: {best_param[1]}')
print(f'Best epoch: {best_param[2]}')

In [None]:
# Validation-accuracy curves from the grid search: one panel per weight decay,
# one curve per learning rate (labels follow the order of learning_rates).
colors = ['green', 'orange', 'blue']
labels = ['lr=1e-3', 'lr=1e-4', 'lr=5e-4']
# Epoch count read from the results tensor so the x-axis always matches the
# grid search that filled x_acc (instead of a separately hard-coded 10).
n_epoch = x_acc.shape[-1]
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 5))

accuracy_panels = (
    (ax1, 0, 'Val_accuracy Weight decay=1e-3'),
    (ax2, 1, 'Val_accuracy Weight decay=1e-4'),
    (ax3, 2, 'Val_accuracy Weight decay=5e-4'),
)
for ax, wd_idx, title in accuracy_panels:
    for i in range(len(labels)):
        ax.plot(x_acc[i, wd_idx, :], label=labels[i], color=colors[i])
    # Tick positions are 0-based indices; tick labels are 1-based epoch numbers.
    ax.set_xticks(ticks=np.arange(n_epoch), labels=np.arange(1, n_epoch + 1))
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.grid(True)
    ax.set_title(title)
    ax.legend()
plt.subplots_adjust(wspace=0.3)

In [None]:
# Validation-loss curves from the grid search, laid out like the accuracy
# plots: one panel per weight decay, one curve per learning rate. Reuses
# colors, labels and n_epoch from the previous cell.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 5))
loss_panels = (
    (ax1, 0, 'Val_loss Weight decay=1e-3'),
    (ax2, 1, 'Val_loss Weight decay=1e-4'),
    (ax3, 2, 'Val_loss Weight decay=5e-4'),
)
for ax, wd_idx, title in loss_panels:
    for i in range(3):
        ax.plot(x_loss[i, wd_idx, :], label=labels[i], color=colors[i])
    ax.set_xticks(ticks=np.arange(n_epoch), labels=np.arange(1, n_epoch + 1))
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True)
    ax.set_title(title)
    ax.legend()
plt.subplots_adjust(wspace=0.3)

In [None]:
# CosineAnnealing scheduler: train from a fresh copy for num_epoch epochs,
# validating after each epoch, then stepping the scheduler.
num_epoch = 20
vits16 = copy.deepcopy(vits16_new)
optimizer = torch.optim.SGD(vits16.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)

# NOTE(review): T_max=200 spans ten times the 20 epochs run here, so only the
# first tenth of the cosine period is used — confirm this is the intended comparison.
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Epochs are numbered 1..num_epoch inclusive so exactly num_epoch epochs run.
for epoch in range(1, num_epoch + 1):
    train(epoch, vits16, train_loader, loss_fn, optimizer)
    val_accuracy, val_loss = validate(vits16, val_loader, loss_fn)
    scheduler_cosine.step()

In [None]:
# Linear scheduler: same setup as the cosine run, with a linear LR decay.
vits16 = copy.deepcopy(vits16_new)
optimizer = torch.optim.SGD(vits16.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)

# NOTE(review): total_iters=200 exceeds num_epoch, so only the first part of
# the 1.0 -> 0.01 ramp is reached in this run.
scheduler_linear = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=200)

# Epochs are numbered 1..num_epoch inclusive so exactly num_epoch epochs run.
for epoch in range(1, num_epoch + 1):
    train(epoch, vits16, train_loader, loss_fn, optimizer)
    val_accuracy, val_loss = validate(vits16, val_loader, loss_fn)
    scheduler_linear.step()

In [None]:
# Exponential scheduler: LR multiplied by gamma=0.9 after every epoch.
vits16 = copy.deepcopy(vits16_new)
optimizer = torch.optim.SGD(vits16.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)

scheduler_exp = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# Epochs are numbered 1..num_epoch inclusive so exactly num_epoch epochs run.
for epoch in range(1, num_epoch + 1):
    train(epoch, vits16, train_loader, loss_fn, optimizer)
    val_accuracy, val_loss = validate(vits16, val_loader, loss_fn)
    scheduler_exp.step()

In [None]:
# StepLR scheduler: LR multiplied by gamma=0.9 every step_size epochs.
vits16 = copy.deepcopy(vits16_new)
optimizer = torch.optim.SGD(vits16.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)

# NOTE(review): step_size=30 is larger than num_epoch (20), so the LR never
# decays in this run and the result equals a constant-LR baseline.
scheduler_step = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.9)

# Epochs are numbered 1..num_epoch inclusive so exactly num_epoch epochs run.
for epoch in range(1, num_epoch + 1):
    train(epoch, vits16, train_loader, loss_fn, optimizer)
    val_accuracy, val_loss = validate(vits16, val_loader, loss_fn)
    scheduler_step.step()

In [None]:
# Final run with the chosen hyper-parameters and cosine schedule: retrain from
# a fresh copy and record the test loss after every epoch.
n_epoch = 50
vits16 = copy.deepcopy(vits16_new)
optimizer = torch.optim.SGD(
    vits16.parameters(),
    lr=1e-4,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
test_loss_vector = []
for epoch in range(1, n_epoch + 1):
    train(epoch, vits16, train_loader, loss_fn, optimizer)
    scheduler.step()
    test_accuracy, test_loss = validate(vits16, test_loader, loss_fn)
    test_loss_vector.append(test_loss)

In [None]:
# Test-loss curve of the final dense run.
plt.figure(figsize=(10, 5))
plt.plot(test_loss_vector, label='Test Loss')
# Tick the first epoch and every 5th one after it. The range comes from the
# losses actually recorded, so an interrupted run still gets a matching axis.
epochs_recorded = len(test_loss_vector)
ticks = [0] + list(range(4, epochs_recorded, 5))
labels = [tick + 1 for tick in ticks]  # 1-based epoch numbers
plt.xticks(ticks=ticks, labels=labels)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.title('Test_loss Weight decay=5e-4, Learning rate=1e-4, Momentum=0.9')
plt.legend()
plt.show()

In [None]:
# Creation of the sparse update mask
def compute_fisher_mask(model, dataloader, sparsity, criterion):
    """Build a binary mask by iterative pruning on accumulated squared gradients.

    For every parameter with ``requires_grad=True`` the score of each entry is
    the sum, over the batches of ``dataloader``, of its squared gradient of
    ``criterion`` (the diagonal empirical Fisher when batches hold one sample).
    Round ``r`` recomputes the scores of the still-unmasked entries and sets
    to 0 the ``sparsity[r]`` fraction of the remaining non-zero scores with
    the largest values. Entries that are still 1 at the end are the
    low-score ones. Entries whose accumulated score is exactly zero are not
    counted when sizing a round and are never masked by it.

    Logits are computed as in training: the first token of
    ``model.get_intermediate_layers(inputs, n=1)`` passed through ``model.head``.

    Args:
        model: Module exposing ``get_intermediate_layers`` and ``head``.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        sparsity: Sequence of per-round pruning fractions; one round per entry.
        criterion: Loss function called as ``criterion(logits, targets)``.

    Returns:
        Dict mapping each trainable parameter tensor to a float tensor of the
        same shape (1 = entry kept, 0 = entry masked).

    Side effects:
        Puts ``model`` in eval mode, clears its gradients at the end, and
        prints the names of parameters with at least one kept entry followed
        by the total zero/one counts.
    """
    device = next(model.parameters()).device
    model.eval()

    trainable = [p for p in model.parameters() if p.requires_grad]
    fisher_scores = {p: torch.zeros_like(p.data) for p in trainable}
    mask = {p: torch.ones_like(p.data) for p in trainable}
    # Built once and used only for the report at the end.
    param_to_name = {param: name for name, param in model.named_parameters()}

    for round_sparsity in sparsity:
        if not trainable:
            break  # nothing to score; torch.cat of an empty list would fail

        for score in fisher_scores.values():
            score.zero_()

        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            intermediate_output = model.get_intermediate_layers(inputs, n=1)
            features = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
            outputs = model.head(features)

            loss = criterion(outputs, targets)

            model.zero_grad()
            loss.backward()

            for param in trainable:
                if param.grad is not None:
                    # Already-masked entries contribute nothing to later rounds.
                    fisher_scores[param] += param.grad.data.pow(2) * mask[param]

        all_scores = torch.cat([score.flatten() for score in fisher_scores.values()])
        non_zero_scores = all_scores[all_scores != 0]
        n_scores = non_zero_scores.numel()
        k = min(int(round_sparsity * n_scores), n_scores)
        if k <= 0:
            continue  # this round prunes nothing

        # kthvalue is 1-based: the k largest scores are exactly those >= the
        # (n - k + 1)-th smallest, so keeping "< threshold" masks k entries.
        threshold, _ = torch.kthvalue(non_zero_scores, n_scores - k + 1)

        for param, score in fisher_scores.items():
            masked_score = score * mask[param]
            mask[param] = ((masked_score < threshold) * mask[param]).float()

    # Release the gradients left over from the last scoring batch.
    model.zero_grad()

    for param, param_mask in mask.items():
        if torch.any(param_mask == 1):
            print(param_to_name[param])

    zero_count = sum((v == 0).sum().item() for v in mask.values())
    one_count = sum((v == 1).sum().item() for v in mask.values())

    print(f"Zeros: {zero_count}, Ones: {one_count}")

    return mask

In [None]:
# Self-implemented SGD with momentum, with a per-entry update mask
def sgdm_sparse(params, lr, momentum, dampening, weight_decay, nesterov, maximize, b, mask):
    """Apply one masked SGD-with-momentum step in place, following torch.optim.SGD.

    For every parameter ``p`` that has a gradient:

    * ``g = (-grad if maximize else grad) + weight_decay * p``;
    * momentum buffer: ``g`` on the first step for ``p`` (no entry in ``b``),
      else ``momentum * buf + (1 - dampening) * g``;
    * direction: ``buf``, or ``g + momentum * buf`` with Nesterov
      (just ``g`` when ``momentum == 0``);
    * ``p -= lr * direction * mask[p]``.

    Args:
        params: Iterable of parameters; those whose ``.grad`` is None are skipped.
        lr, momentum, dampening, weight_decay, nesterov, maximize: Same meaning
            as the torch.optim.SGD arguments of the same names.
        b: Dict of momentum buffers keyed by parameter; updated in place.
        mask: Dict giving, for every parameter with a gradient, a tensor of
            the parameter's shape; entries equal to 0 are never moved.

    Returns:
        ``b`` (the same dict that was passed in).

    Raises:
        KeyError: If a parameter with a gradient has no entry in ``mask``.
    """
    with torch.no_grad():
        for param in params:
            if param.grad is None:
                continue

            # As in torch.optim.SGD, maximize negates only the gradient, so
            # weight decay keeps shrinking the weights when ascending.
            grad = -param.grad if maximize else param.grad
            if weight_decay != 0:
                grad = grad + weight_decay * param

            if momentum != 0:
                buf = b.get(param)
                if buf is None:
                    # First step: the buffer is seeded with the gradient itself
                    # (no dampening), matching torch.optim.SGD. clone() avoids
                    # aliasing param.grad.
                    buf = grad.clone()
                else:
                    buf = momentum * buf + (1 - dampening) * grad
                b[param] = buf
                update = grad + momentum * buf if nesterov else buf
            else:
                update = grad

            # Masked entries (mask == 0) receive no update at all, decay included.
            param.add_(update * mask[param], alpha=-lr)
    return b

In [None]:
def train_sgd_sparse(epoch, model, train_loader, criterion, sparsity, lr, momentum, weight_decay, mask,
                     momentum_buffer=None):
    """Run one training epoch with the masked SGDM step ``sgdm_sparse`` and print metrics.

    Logits are computed as in ``train``: the first token of
    ``model.get_intermediate_layers(inputs, n=1)`` passed through ``model.head``.
    After each backward pass, ``sgdm_sparse`` updates every parameter that has
    a gradient, scaled entry-wise by ``mask`` (dampening 0, no Nesterov,
    minimisation).

    Args:
        epoch: Epoch number; used only in the printed log line.
        model: Module exposing ``get_intermediate_layers`` and ``head``.
        train_loader: Iterable of ``(inputs, targets)`` batches supporting ``len``.
        criterion: Loss function called as ``criterion(logits, targets)``.
        sparsity: Not used here (the mask already encodes it); kept so existing
            calls keep working.
        lr, momentum, weight_decay: SGD hyper-parameters.
        mask: Dict mapping each parameter with a gradient to its update mask.
        momentum_buffer: Optional dict of momentum buffers keyed by parameter,
            handed to ``sgdm_sparse``. With the default ``None`` a fresh dict is
            used, so momentum restarts every epoch (the previous behaviour). Pass
            the same dict on every epoch to carry momentum across epochs, as
            torch.optim.SGD does.

    Returns:
        ``(train_loss, train_accuracy)``: mean per-batch loss and accuracy in percent.
    """
    model.train()
    # Follow the model's device rather than hard-coding .cuda().
    device = next(model.parameters()).device
    running_loss = 0.0
    correct = 0
    total = 0
    params = list(model.parameters())
    dampening = 0
    nesterov = False
    maximize = False
    b = {} if momentum_buffer is None else momentum_buffer

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        intermediate_output = model.get_intermediate_layers(inputs, n=1)
        features = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
        outputs = model.head(features)

        loss = criterion(outputs, targets)
        model.zero_grad()
        loss.backward()

        b = sgdm_sparse(params, lr, momentum, dampening, weight_decay, nesterov, maximize, b, mask)

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    train_loss = running_loss / len(train_loader)
    train_accuracy = 100. * correct / total
    print(f'Train Epoch: {epoch} Loss: {train_loss:.6f} Acc: {train_accuracy:.2f}%')
    return train_loss, train_accuracy

In [None]:
# Pre-training only the head before computing the sparse mask.
num_epochs = 5

vits16_sparse = copy.deepcopy(vits16_new)
optimizer = torch.optim.SGD(vits16_sparse.head.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)

# Only the head is in the optimizer, but train() backpropagates into every
# parameter with requires_grad=True, and optimizer.zero_grad() clears only
# the head's gradients. Temporarily freeze the other trainable parameters so
# no backbone gradients are computed or accumulated. The head's gradients
# (and therefore its updates) are unaffected.
head_param_ids = {id(p) for p in vits16_sparse.head.parameters()}
temporarily_frozen = [
    p for p in vits16_sparse.parameters()
    if p.requires_grad and id(p) not in head_param_ids
]
for p in temporarily_frozen:
    p.requires_grad = False
try:
    for epoch in range(1, num_epochs + 1):
        train(epoch, vits16_sparse, train_loader, loss_fn, optimizer)
        validate(vits16_sparse, test_loader, loss_fn)
finally:
    # Restore the original trainable set; later cells select parameters by requires_grad.
    for p in temporarily_frozen:
        p.requires_grad = True

In [None]:
#Mask computation
# Rebuild the training Subset: the global name `train` was rebound to the
# training function after the split cell created it.
train_subset = torch.utils.data.Subset(train_val, train_indices)
# Use the first 10% of the training indices to score parameters.
num_samples = int(0.1 * len(train_subset))
small_train_indices = list(range(num_samples))
small_train = torch.utils.data.Subset(train_subset, small_train_indices)
# batch_size=1 so every squared gradient accumulated by compute_fisher_mask
# comes from a single sample.
small_train_loader = torch.utils.data.DataLoader(small_train, batch_size=1, shuffle=True)

# Pruning fraction for each round: sparsity[r] is the share of the remaining
# non-zero scores masked in round r.
sparsity=[0.1, 0.2, 0.3, 0.4, 0.66]
mask = compute_fisher_mask(vits16_sparse, small_train_loader, sparsity, loss_fn)

In [None]:
#Test Sparse SGDM
# Train the head-pretrained vits16_sparse with masked SGDM for num_epochs
# epochs, recording the test loss after every epoch.
# NOTE(review): no momentum state is carried between these calls — check
# whether train_sgd_sparse restarts momentum on each epoch and whether that is intended.
test_loss_vector_sparse=[]
num_epochs=200
for epoch in range(1, num_epochs+1):
    train_sgd_sparse(epoch, vits16_sparse, train_loader, loss_fn,sparsity, lr=1e-4, momentum=0.9, weight_decay=5e-4, mask=mask)
    test_accuracy_sparse, test_loss_sparse = validate(vits16_sparse, test_loader, loss_fn)
    test_loss_vector_sparse.append(test_loss_sparse)

In [None]:
# Test-loss curve of the sparse (masked SGDM) run.
plt.figure(figsize=(12, 5))
plt.plot(test_loss_vector_sparse, label='Test Loss')
# Tick the first epoch and every 5th one after it. The range comes from the
# losses actually recorded, so the axis matches even if training was stopped early.
epochs_recorded_sparse = len(test_loss_vector_sparse)
ticks = [0] + list(range(4, epochs_recorded_sparse, 5))
labels = [tick + 1 for tick in ticks]  # 1-based epoch numbers
plt.xticks(ticks=ticks, labels=labels, fontsize=8)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.title('Test_loss Weight decay=5e-4, Learning rate=1e-4, Momentum=0.9')
plt.legend()
plt.show()