In [1]:
import os
import shutil
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/dogvscat

Mounted at /content/drive
/content/drive/MyDrive/dogvscat


In [None]:
!pip install timm

In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
import shutil
from sklearn.model_selection import train_test_split
import timm  # Importing timm for Vision Transformer models

# Source and target directories
source_dir = './images'
binary_dir = './binary_classification'
multiclass_dir = './multiclass_classification'


# Split data into train and test directories
binary_source_dir = './binary_classification'
binary_train_dir = './train_binary'
binary_test_dir = './test_binary'
multiclass_source_dir = './multiclass_classification'
multiclass_train_dir = './train_multiclass'
multiclass_test_dir = './test_multiclass'

device = torch.device("cuda:0")

# Load a pre-trained Vision Transformer model
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model = model.to(device)


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [5]:

# Freeze all parameters and replace the classifier head
for param in model.parameters():
    param.requires_grad = False
num_features = model.head.in_features
model.head = nn.Linear(num_features, 37).to(device)

# Data transformations
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Dataset and DataLoader
train_dataset = datasets.ImageFolder(root=binary_train_dir, transform=transform)
test_dataset = datasets.ImageFolder(root=binary_test_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Optimizer setup
optimizer = Adam(model.head.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss().to(device)

# Training Function
def train_model(model, criterion, optimizer, num_epochs=1):
    model.train()
    for epoch in range(num_epochs):
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Evaluation Function
def test_model(model):
    model.eval()
    total = correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

# Training and evaluation
train_model(model, criterion, optimizer)
test_model(model)


Epoch [1/1], Loss: 4.1982
Epoch [1/1], Loss: 4.4555
Epoch [1/1], Loss: 4.1850
Epoch [1/1], Loss: 4.1088
Epoch [1/1], Loss: 3.9691
Epoch [1/1], Loss: 3.7597
Epoch [1/1], Loss: 3.9038
Epoch [1/1], Loss: 4.1430
Epoch [1/1], Loss: 3.6758
Epoch [1/1], Loss: 3.6882
Epoch [1/1], Loss: 3.8500
Epoch [1/1], Loss: 3.7804
Epoch [1/1], Loss: 3.8415
Epoch [1/1], Loss: 3.6088
Epoch [1/1], Loss: 3.5363
Epoch [1/1], Loss: 3.1078
Epoch [1/1], Loss: 3.6158
Epoch [1/1], Loss: 3.5500
Epoch [1/1], Loss: 3.6677
Epoch [1/1], Loss: 3.3121
Epoch [1/1], Loss: 2.9518
Epoch [1/1], Loss: 3.1751
Epoch [1/1], Loss: 3.0420
Epoch [1/1], Loss: 3.0811
Epoch [1/1], Loss: 3.1652
Epoch [1/1], Loss: 2.8551
Epoch [1/1], Loss: 3.0794
Epoch [1/1], Loss: 2.9631
Epoch [1/1], Loss: 3.0492
Epoch [1/1], Loss: 3.0934
Epoch [1/1], Loss: 2.9691
Epoch [1/1], Loss: 2.9861
Epoch [1/1], Loss: 2.6370
Epoch [1/1], Loss: 2.6788
Epoch [1/1], Loss: 2.5723
Epoch [1/1], Loss: 2.6045
Epoch [1/1], Loss: 2.8104
Epoch [1/1], Loss: 2.4457
Epoch [1/1],