In [1]:
import torch, torchvision
from torchvision.transforms import ToTensor, Resize
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import DataLoader
import torch.nn as nn
from torch import optim
from tqdm import tqdm
import copy
import numpy as np
from torch.utils.data import Dataset
import requests
from sklearn.model_selection import train_test_split

In [2]:
TRAIN_PATH="./data/train"
VAL_PATH='./data/validationdp'
NUM_BATCH=32
EPOCHS=3
LEARNING_RATE=1e-3
DEVICE="cuda:0" 
print(DEVICE)

cuda:0


In [3]:
from pytorch_pretrained_vit import ViT
model = ViT("B_16_imagenet1k", pretrained=True)
#freezing without fc
for param in model.parameters():
    param.requires_grad=False

model.fc=nn.Sequential(*[
    nn.Linear(in_features=768, out_features=2),
    #nn.Softmax(dim=1)
])

Loaded pretrained weights.


In [4]:
transform=transforms.Compose([
    ToTensor(),
    Resize((384,384))
])

In [5]:
class Dog_Cat_Dataset(Dataset):
    def __init__(self, dir, transform=None):
        self.dir=dir
        self.transform=transform
        self.images=os.listdir(dir)
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        image_path=os.path.join(self.dir,self.images[index])
        label=self.images[index].split(".")[0]
        label=0 if label=='dog' else 1
        image=np.array(Image.open(image_path))
        image=self.transform(image)
        return image,label

In [6]:
train_data=Dog_Cat_Dataset(TRAIN_PATH,transform)
train_dataloader=DataLoader(train_data,batch_size=NUM_BATCH)
val_data=Dog_Cat_Dataset(VAL_PATH,transform)
val_dataloader=DataLoader(val_data,batch_size=NUM_BATCH)

In [7]:
def validate(model, data):
    total =0
    correct=0

    for(images, labels) in data:
        images=images.to(DEVICE)
        labels=labels.to(DEVICE)
        x=model(images)
        pred=torch.argmax(x,1)
        total += x.size(0)
        correct += torch.sum(pred==labels)
        
    return correct*100/total

In [8]:
def train(num_epoch=EPOCHS, lr=LEARNING_RATE, device=DEVICE):
    cnn=model.to(device)
    cel=nn.CrossEntropyLoss()
    optimizer=optim.Adam(cnn.parameters(),lr=lr,weight_decay=0.5)

    max_accuracy=0

    for epoch in range(num_epoch):
        for i, (images,labels) in tqdm(enumerate(train_dataloader)):
            images=images.to(device)
            labels=labels.to(device)
            optimizer.zero_grad()
            pred=cnn(images)
            loss=cel(pred,labels)
            loss.backward()
            optimizer.step()
            if i%700==0 and i!=0:
                print("loss is",loss.item())
        
        accuracy=float(validate(cnn,val_dataloader))
        if accuracy>max_accuracy:
            best_model=copy.deepcopy(cnn)
            max_accuracy=accuracy
            print("find best!")
        print("Epoch: ",epoch+1,"Accuracy: ",accuracy,"%")
    
    return best_model

In [9]:
ResNet=train()

702it [04:58,  2.33it/s]

loss is 0.020126352086663246


719it [05:06,  2.35it/s]


find best!
Epoch:  1 Accuracy:  91.32500457763672 %


702it [05:02,  2.33it/s]

loss is 0.02400299347937107


719it [05:09,  2.32it/s]


find best!
Epoch:  2 Accuracy:  91.45000457763672 %


702it [05:03,  2.32it/s]

loss is 0.031747397035360336


719it [05:10,  2.32it/s]


find best!
Epoch:  3 Accuracy:  92.0250015258789 %


In [None]:
best_cnn=ResNet.to("cpu")
torch.save(ResNet, 'ViT_new.pth')