# Basic Import

In [1]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as models
import tqdm
import clip
import torch
import torch.nn as nn
import torch.optim as optim
import funct

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Dataset

In [2]:
# Preprocessing for the ResNet datasets: resize each MNIST digit to 224 px and
# expand its single grayscale channel into three identical channels.
# Grayscale(num_output_channels=3) is used instead of
# Lambda(lambda x: x.repeat(3, 1, 1)) because a lambda cannot be pickled, which
# breaks the multi-worker DataLoaders below on platforms that start workers
# with "spawn" (Windows, macOS). The resulting tensor is the same: three equal
# channels.
transform = transforms.Compose([
    transforms.Resize(size=224),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])

# Datasets for ResNet training and testing share the transform above.
resnet_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
resnet_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Dataset for CLIP testing: no transform, so samples stay as raw images
# (presumably CLIP's own `preprocess` is applied inside funct.clip_testing - confirm).
clip_test = datasets.MNIST(root='./data', train=False, download=True, transform=None)

# Training batches are shuffled; test batches keep the dataset order.
train_dataloader = torch.utils.data.DataLoader(
    resnet_train, batch_size=50, shuffle=True, num_workers=2
)
test_dataloader = torch.utils.data.DataLoader(
    resnet_test, batch_size=15, shuffle=False, num_workers=2
)


# Evaluating CLIP

In [3]:
# CLIP setup: choose the visual backbone and the candidate labels (digits 0-9).
VISUAL_BACKBONE = 'RN50x64'
numbers = list(range(10))

# clip.load returns the model together with its matching image preprocessing;
# model files are stored under download_root.
model, preprocess = clip.load(
    VISUAL_BACKBONE,
    device=device,
    download_root='/shareddata/clip/',
)

# One text prompt per digit, tokenized in a single call into one stacked tensor.
text_inputs = clip.tokenize(
    [f"a photo of the number \"{c}\"" for c in numbers]
).to(device)

In [4]:
# Evaluate CLIP on the untransformed MNIST test split using the digit prompts.
# The result is scaled by 100 for display, so it is presumably a fraction in
# [0, 1] - confirm against funct.clip_testing.
accuracy = funct.clip_testing(
    model,
    preprocess,
    clip_test,
    device,
    text_inputs,
)

print(f"the accuracy of Clip on dataset MNIST is {accuracy*100:.2f}%, visual encoder is {VISUAL_BACKBONE}")

100%|██████████| 10000/10000 [06:33<00:00, 25.42it/s]

the accuracy of Clip on dataset Mnist is 84.86%, visual encoder is RN50x64





# Fine-tuning and Evaluating ResNet50

In [5]:
# Load the pretrained ResNet50 and replace its final fully connected layer
# with a 10-output head (one output per MNIST digit class).
resnet50 = models.resnet50(pretrained=True)
resnet50.fc = nn.Linear(resnet50.fc.in_features, 10)  # in_features is 2048 for ResNet50
resnet50 = resnet50.to(device)

# Fine-tuning setup: SGD with momentum over all model parameters,
# trained with a cross-entropy loss.
optimizer = optim.SGD(resnet50.parameters(), lr=1e-2, momentum=0.9)
criterion = nn.CrossEntropyLoss()



In [6]:
# Fine-tune ResNet50 on the training loader.
funct.resnet_training(resnet50, criterion, optimizer, train_dataloader, device)

# Evaluate on the test loader; `corrects` is presumably the number of correct
# predictions (confirm against funct.resnet_testing). Dividing by the size of
# the test dataset turns it into an accuracy.
corrects = funct.resnet_testing(resnet50, test_dataloader, device)
accuracy = corrects / len(test_dataloader.dataset)
print(f"the accuracy of ResNet on MNIST dataset is {accuracy*100:.2f}%")

100%|██████████| 1200/1200 [04:51<00:00,  4.11it/s]
100%|██████████| 667/667 [00:13<00:00, 50.63it/s]

the accuracy of ResNet on Mnist dataset is 99.20%



