In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader,random_split
from torchvision import transforms,datasets
from transformers import ViTForImageClassification,ViTFeatureExtractor
from tqdm import tqdm

In [3]:
folderPath = "../../../Data/images/"

NUM_CLASSES = 7
BATCH_SIZE = 32
IMG_SIZE = 224
EPOCHS = 5
TRAIN_DIR=folderPath+"train"
DEVICE =torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



In [5]:
train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

In [6]:
# def get_splits(img_size=(224,224),batch_size=64,val_split=0.1,seed=42,train_dir=folderPath+"train" ):
#     train_ds=image_dataset

full_dataset= datasets.ImageFolder(root=TRAIN_DIR,transform=train_transform)




In [7]:
# Define split ratios (e.g., 70% train, 15% validation, 15% test)
train_size = int(0.9 * len((full_dataset)))
val_size = len(full_dataset) - train_size

# Split the dataset
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size]
)

In [8]:
val_dataset.dataset.transform = val_transform

train_loader = DataLoader(train_dataset,batch_size= BATCH_SIZE,shuffle=True,num_workers=4)
val_loader = DataLoader(val_dataset,batch_size= BATCH_SIZE,shuffle=False,num_workers=4)

In [9]:
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=NUM_CLASSES
)

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
optimizer = torch.optim.AdamW(model.parameters(),lr=5e-5)
criterion = nn.CrossEntropyLoss()

In [11]:
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct=0
    total=0
    loop=tqdm(train_loader,desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
    for images,labels in loop:
      #  images,labels = images.to(DEVICE) ,labels.to(DEVICE)
       optimizer.zero_grad()
       outputs = model(images).logits
       loss = criterion(outputs,labels)
       loss.backward()
       optimizer.step()

       running_loss += loss.item()
       _,preds = torch.max(outputs,1)
       correct +=(preds ==labels).sum().item()
       total += labels.size(0)
       loop.set_postfix(loss=running_loss/(total/BATCH_SIZE),acc = correct/total)


    model.eval()
    val_correct,val_total = 0,0
    val_loss = 0.0
    with torch.no_grad():
        for images,labels in val_loader:
          #  images,labels = images.to(DEVICE) ,labels.to(DEVICE)
           outputs = model(images).logits
           loss = criterion(outputs,labels)
           val_loss += loss.item()
           _,preds = torch.max(outputs,1)
           val_correct += (preds == labels).sum().item()
           val_total += labels.size(0)
    print(f"Epoch {epoch+1} - Val loss: {val_loss/len(val_loader): .4f}, Val Acc:{val_correct/val_total:.4f}")


torch.save(model.state_dict(),"vit_emotion_pytorch.pth")
print("Model saved")

Epoch 1/5 [Train]: 100%|██████████| 811/811 [1:20:23<00:00,  5.95s/it, acc=0.596, loss=1.1] 


Epoch 1 - Val loss:  0.9201, Val Acc:0.6618


Epoch 2/5 [Train]: 100%|██████████| 811/811 [1:18:24<00:00,  5.80s/it, acc=0.71, loss=0.807] 


Epoch 2 - Val loss:  0.8694, Val Acc:0.6781


Epoch 3/5 [Train]: 100%|██████████| 811/811 [1:17:46<00:00,  5.75s/it, acc=0.791, loss=0.601]


Epoch 3 - Val loss:  0.8760, Val Acc:0.6937


Epoch 4/5 [Train]: 100%|██████████| 811/811 [1:15:16<00:00,  5.57s/it, acc=0.87, loss=0.396] 


Epoch 4 - Val loss:  0.9933, Val Acc:0.6889


Epoch 5/5 [Train]: 100%|██████████| 811/811 [1:17:58<00:00,  5.77s/it, acc=0.925, loss=0.243]


Epoch 5 - Val loss:  1.0595, Val Acc:0.6916
Model saved
