# Basic Import

In [1]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as models
import tqdm
import clip
import torch
import torch.nn as nn
import torch.optim as optim
import funct

In [2]:
# Number of passes over the training set used for fine-tuning.
EPOCH = 3

# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset

In [3]:
# Preprocessing for the ResNet50 pipeline: resize every image to 224x224 and
# convert it to a tensor.
# NOTE(review): no mean/std normalization is applied here, although the
# ResNet50 used later is loaded with pretrained weights — confirm intentional.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# CIFAR-100 train/test splits (with the transform above) for fine-tuning and
# evaluating ResNet50.
resnet_train = datasets.CIFAR100(
    root='./data', train=True, download=True, transform=transform
)
resnet_test = datasets.CIFAR100(
    root='./data', train=False, download=True, transform=transform
)

# CIFAR-100 test split without any transform, used for evaluating CLIP.
clip_test = datasets.CIFAR100(root='./data', train=False, download=True)

# Batches of 20 samples loaded with 2 workers; only the training loader shuffles.
train_dataloader = torch.utils.data.DataLoader(
    resnet_train, batch_size=20, shuffle=True, num_workers=2
)
test_dataloader = torch.utils.data.DataLoader(
    resnet_test, batch_size=20, shuffle=False, num_workers=2
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


# Evaluating CLIP

In [4]:
# Name of the CLIP visual encoder variant to load and evaluate.
VISUAL_BACKBONE = "ViT-L/14"

# CIFAR-100 class names; one text prompt is built per class below.
names = resnet_train.classes

# clip.load returns the model and its companion `preprocess` object.
# NOTE(review): download_root is a hard-coded absolute path — confirm it
# exists on the machine running this notebook.
model, preprocess = clip.load(
    VISUAL_BACKBONE,
    device,
    download_root='/shareddata/clip/',
)

# Tokenize all class prompts in a single call (clip.tokenize accepts a list of
# strings and returns one row per string) and move the result to `device`.
# NOTE(review): CLIP's README zero-shot example uses "a photo of a {c}" —
# confirm that omitting the article here is intended.
prompts = [f"a photo of {c}" for c in names]
text_inputs = clip.tokenize(prompts).to(device)

In [5]:
# Evaluate CLIP on the untransformed CIFAR-100 test split.
# Presumably funct.clip_testing returns the accuracy as a fraction in [0, 1],
# since it is multiplied by 100 for display — verify against funct.
accuracy = funct.clip_testing(
    model,
    preprocess,
    clip_test,
    device,
    text_inputs,
)

print(f"the accuracy of Clip on CIFAR100 dataset is {accuracy*100:.2f}%, visual encoder is {VISUAL_BACKBONE}")

100%|██████████| 10000/10000 [09:49<00:00, 16.96it/s]

the accuracy of Clip on CIFAR100 dataset is 65.27%, visual encoder is ViT-L/14





# Fine-tuning and evaluating ResNet50

In [6]:
# Start from a ResNet50 with pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favor of the `weights=` argument — check the installed version before
# switching.
resnet50 = models.resnet50(pretrained=True)

# Replace (not add) the final fully connected layer so the network emits one
# logit per dataset class. The input width is read from the existing head and
# the output width from the dataset, rather than hard-coding 2048 and 100.
resnet50.fc = nn.Linear(resnet50.fc.in_features, len(resnet_train.classes))
resnet50 = resnet50.to(device)

# Cross-entropy loss, optimized with SGD plus momentum over all model
# parameters (the whole network is fine-tuned, not only the new head).
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet50.parameters(), lr=1e-2, momentum=0.9)



In [7]:
# Fine-tune ResNet50 for EPOCH passes over the training set and report the
# test accuracy after each pass. Only a few epochs are run, so the accuracy
# may not have converged yet.
# Note: `accuracy` is reused here and overwrites the CLIP accuracy computed
# in the earlier evaluation cell.
for epoch in range(EPOCH):
    funct.resnet_training(resnet50, criterion, optimizer, train_dataloader, device)
    # The result is divided by the test-set size below, so it is presumably
    # the number of correct predictions — verify against funct.
    corrects = funct.resnet_testing(resnet50, test_dataloader, device)
    accuracy = corrects / len(resnet_test)
    # Fixed: the original message lacked a space before the epoch number ("is1").
    print(f"the accuracy of ResNet on CIFAR100 dataset is {accuracy*100:.2f}%, the training epoch is {epoch+1}")

100%|██████████| 2500/2500 [04:49<00:00,  8.62it/s]
100%|██████████| 500/500 [00:12<00:00, 40.11it/s]


the accuracy of ResNet on CIFAR100 dataset is 59.00%, the training epoch is1


100%|██████████| 2500/2500 [04:44<00:00,  8.79it/s]
100%|██████████| 500/500 [00:12<00:00, 39.58it/s]


the accuracy of ResNet on CIFAR100 dataset is 66.09%, the training epoch is2


100%|██████████| 2500/2500 [05:01<00:00,  8.30it/s]
100%|██████████| 500/500 [00:13<00:00, 37.35it/s]

the accuracy of ResNet on CIFAR100 dataset is 70.55%, the training epoch is3



