In [None]:
import pandas as pd
from torch.utils.data import DataLoader, Subset
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from tqdm import tqdm
import time
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import wandb

In [None]:
# Candidate checkpoints; the model-loading cell fine-tunes model_path[0].
model_path = [
    'epoch_8_best_model.pt',
    'resnet50_dist_against_non_fine_tuned_23_07.pth',
    'resnet50_dist_against_finetuned_clip_24_07_2024.pth',
]
# wandb project name = checkpoint file name + learning-rate tag.
project_name = model_path[0] + '0.001LR'  # Replace with your project name
wandb.init(project=project_name)

In [None]:
# Prefer the first GPU; fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# torch.cuda.get_device_name only accepts CUDA devices; calling it with the
# CPU fallback raises, so only report the GPU name when one is in use.
if device.type == "cuda":
    print(torch.cuda.get_device_name(device))

In [None]:
import os

# Root of the OOD/IND dataset split (sub-folders train, test/ and test_to_add/
# are read by the data-loading cell). Set the DATASET_DIR environment variable
# to point elsewhere instead of editing this machine-specific absolute path.
dataset_dir = os.environ.get(
    "DATASET_DIR",
    "/home/user1/ariel/fed_learn/large_vlm_distillation_ood/s_cars_ood_ind_test_test_val/",
)
# Sub-paths are built by plain string concatenation (dataset_dir + "test/"),
# so the root must end with a path separator.
if not dataset_dir.endswith(("/", os.sep)):
    dataset_dir += "/"

In [None]:
def _make_tfms(augment):
    """Resize to 400x400, optionally augment, then map each channel to [-1, 1]."""
    steps = [transforms.Resize((400, 400))]
    if augment:
        # Light augmentation, used only for the fine-tuning set.
        steps += [transforms.RandomHorizontalFlip(), transforms.RandomRotation(15)]
    steps += [transforms.ToTensor(),
              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(steps)


train_tfms = _make_tfms(augment=True)
test_tfms = _make_tfms(augment=False)
# Same pipeline as train_tfms; not used by the loaders built in this cell.
train_ind_tfms = _make_tfms(augment=True)

# Fine-tuning set: 5 labelled samples of each OOD class.
dataset_train = torchvision.datasets.ImageFolder(root=dataset_dir + "test_to_add/", transform=train_tfms)
trainloader = torch.utils.data.DataLoader(dataset_train, batch_size=5, shuffle=True, num_workers=4)

# OOD evaluation: same classes as the fine-tuning set, different instances.
# Accuracy does not depend on batch order, so evaluation loaders are not shuffled.
dataset_test_ood = torchvision.datasets.ImageFolder(root=dataset_dir + "test/", transform=test_tfms)
testloader_ood = torch.utils.data.DataLoader(dataset_test_ood, batch_size=50, shuffle=False, num_workers=4)

# ImageFolder derives label indices from the sorted sub-folder names of each
# root, so the two OOD splits only share a label space if their class folders match.
assert dataset_train.class_to_idx == dataset_test_ood.class_to_idx, (
    "test_to_add/ and test/ contain different class folders; their labels would not line up")

# In-distribution check: evaluate on the training split to make sure the model
# as saved still recognizes in-distribution data.
dataset_test_ind = torchvision.datasets.ImageFolder(root=dataset_dir + "train", transform=test_tfms)
testloader_ind_on_train = torch.utils.data.DataLoader(dataset_test_ind, batch_size=32, shuffle=False, num_workers=4)

In [None]:
def _evaluate_splits(model, testloader_ood, testloader_ind_on_train):
    """Score `model` on the OOD test split and the in-distribution train split.

    Switches the model to eval mode, logs both accuracies to wandb and returns
    them as ``(test_ood, test_ind_on_train)`` (percentages from eval_model).
    """
    model.eval()
    test_ood = eval_model(model, testloader_ood, 'test_on_ood')
    wandb.log({"test_ood_acc": test_ood})
    test_ind_on_train = eval_model(model, testloader_ind_on_train, 'test_ind_on_train')
    wandb.log({"test_ind_on_train": test_ind_on_train})
    return test_ood, test_ind_on_train


def _train_one_epoch(trainloader, model, criterion, optimizer):
    """Run one optimisation pass over `trainloader`.

    Returns ``(mean_batch_loss, accuracy_percent)``. Accuracy is divided by the
    number of samples actually seen, so it does not assume a fixed batch size.

    Raises:
        ValueError: if `trainloader` yields no batches.
    """
    model.train()
    running_loss = 0.0
    running_correct = 0
    n_batches = 0
    n_samples = 0
    for inputs, labels in tqdm(trainloader, desc="Training few shots on ood", leave=False):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # argmax over the class dimension (dim 1); the result carries no grad.
        running_correct += (outputs.argmax(dim=1) == labels).sum().item()
        n_batches += 1
        n_samples += labels.size(0)
    if n_samples == 0:
        raise ValueError("trainloader yielded no samples")
    return running_loss / n_batches, 100.0 * running_correct / n_samples


def train_model(trainloader, testloader_ood, testloader_ind_on_train, model, criterion, optimizer, lrscheduler,
                n_epochs=5):
    """Fine-tune `model` on the few-shot OOD loader and track accuracies.

    The model is scored once before any training (baseline) and again after
    every epoch, so the scheduler always steps on the accuracy of the current
    weights and the final fine-tuned model is evaluated too.

    Args:
        trainloader: few-shot OOD training loader.
        testloader_ood: OOD test loader (same classes, different instances).
        testloader_ind_on_train: in-distribution loader used as a sanity check.
        model: the network to fine-tune; moved to the global `device`.
        criterion: loss function applied to (outputs, labels).
        optimizer: optimizer over the model's parameters.
        lrscheduler: scheduler whose ``step`` takes the OOD test accuracy
            (e.g. ReduceLROnPlateau in 'max' mode).
        n_epochs: number of training epochs.

    Returns:
        (model, losses, train_ood_accuracies, test_ood_accuracies,
        test_ind_on_train_accuracies). `losses` and `train_ood_accuracies` have
        n_epochs entries; the two test lists have n_epochs + 1 entries where
        index k is the accuracy after k training epochs (index 0 = baseline).
    """
    losses = []
    train_ood_accuracies = []
    model = model.to(device)

    test_ood, test_ind_on_train = _evaluate_splits(model, testloader_ood, testloader_ind_on_train)
    test_ood_accuracies = [test_ood]
    test_ind_on_train_accuracies = [test_ind_on_train]

    for epoch in range(n_epochs):
        since = time.time()
        epoch_loss, epoch_acc = _train_one_epoch(trainloader, model, criterion, optimizer)
        epoch_duration = time.time() - since
        print(
            f"\nEpoch {epoch + 1}, duration: {epoch_duration:.2f} s, OOD_train_loss: {epoch_loss:.4f}, ood_train acc: {epoch_acc:.2f}")
        wandb.log({"train_ood_loss": epoch_loss, "train_ood_acc": epoch_acc})
        losses.append(epoch_loss)
        train_ood_accuracies.append(epoch_acc)

        # Evaluate the weights produced by this epoch, then let the scheduler
        # react to that (not the previous epoch's) OOD accuracy.
        test_ood, test_ind_on_train = _evaluate_splits(model, testloader_ood, testloader_ind_on_train)
        test_ood_accuracies.append(test_ood)
        test_ind_on_train_accuracies.append(test_ind_on_train)
        lrscheduler.step(test_ood)

    print('Finished Training')
    return model, losses, train_ood_accuracies, test_ood_accuracies, test_ind_on_train_accuracies


In [None]:
def eval_model(model, testloader_ood, name):
    """Compute top-1 accuracy (in %) of `model` over every batch of a loader.

    Despite its name, `testloader_ood` may be any loader yielding
    (images, labels) batches; train_model also passes the in-distribution loader.

    Args:
        model: classifier whose output scores classes along dim 1. It is put in
            eval mode here, so the result never depends on the caller's mode.
        testloader_ood: DataLoader of (images, labels) batches.
        name: label printed next to the accuracy.

    Returns:
        Accuracy in percent, as a float.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(testloader_ood, desc="test", leave=False):
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError(f"{name}: loader yielded no samples")
    test_acc = 100.0 * correct / total
    print(f'\n{name}: {test_acc:.2f}')
    return test_acc


In [None]:
# The checkpoint holds a whole pickled nn.Module (it is used via .to() and
# .parameters() below), so full unpickling is needed: weights_only=False.
# SECURITY: unpickling can execute arbitrary code — only load trusted files.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model_dist = torch.load(model_path[0], map_location=device, weights_only=False)
print('\nmodel was loaded\n')
model_ft = model_dist.to(device)

criterion = nn.CrossEntropyLoss()
lr = 0.001
optimizer = optim.SGD(model_ft.parameters(), lr=lr, momentum=0.9)
# The scheduler monitors an accuracy in percent, so `threshold` is an absolute
# margin: an epoch only counts as an improvement if it beats the best accuracy
# by more than 0.9 points. In the default relative mode, 0.9 would demand a 90%
# relative gain, so almost every epoch would count as "no improvement" and the
# LR would be cut every patience + 1 epochs.
lrscheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=3, threshold=0.9, threshold_mode='abs')
n_epochs = 40

In [None]:
# Fine-tune the loaded checkpoint and keep the per-epoch histories for the
# saving/plotting cells below.
(model, losses, train_ood_accuracies,
 test_ood_accuracies, test_ind_accuracies) = train_model(
    trainloader,
    testloader_ood,
    testloader_ind_on_train,
    model_ft,
    criterion,
    optimizer,
    lrscheduler,
    n_epochs=n_epochs,
)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
# Output folder for the accuracy CSV and the plot written by the cells below.
src = ('/home/user1/ariel/fed_learn/large_vlm_distillation_ood/'
       'resnet18_classification_on_s_cars_dataset/'
       's_cars_few_shot_train/')

In [None]:
import os

# Save the OOD test-accuracy history as header-less (epoch, accuracy %) rows;
# row k is the evaluation after k training epochs.
# The file is named after model_path[0], the checkpoint that was loaded and
# fine-tuned, not an unrelated entry of model_path.
os.makedirs(src, exist_ok=True)
df_path = src + model_path[0] + 'test_ood_acc.csv'
df = pd.DataFrame({
    'epoch': range(len(test_ood_accuracies)),
    'test_ood_acc': test_ood_accuracies,
})
df.to_csv(df_path, header=None, index=False)
# Re-read from disk so the following cells use exactly what was saved
# (no header, so the columns come back as 0 = epoch, 1 = accuracy).
df = pd.read_csv(df_path, header=None)

In [None]:
# Column-wise maximum of the saved accuracy history.
df.max(axis=0)

In [None]:
# Split the saved history into x (epoch) and y (OOD accuracy, %).
# A two-column file holds (epoch, accuracy) rows; a one-column file holds only
# accuracies, in which case the row position is the epoch.
if df.shape[1] >= 2:
    x = df.iloc[:, 0]
    y = df.iloc[:, 1]
else:
    x = df.index.to_series()
    y = df.iloc[:, 0]

In [None]:
# Plot OOD test accuracy against the evaluation index (epoch) and save it
# next to the CSV.
fig, ax = plt.subplots()
ax.plot(x, y)
ax.set_xlabel('epoch')
ax.set_ylabel('accuracy, %')
ax.set_title('OOD test accuracy during few-shot fine-tuning')
fig.savefig(src + 'test_accuracies_ood.png')
plt.show()

# Best OOD accuracy, ignoring the first SKIP_FIRST evaluations
# (presumably to exclude early, unstable epochs — confirm).
SKIP_FIRST = 10
max_acc = y.iloc[SKIP_FIRST:].max()
print(max_acc)