In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import timm
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dataset split locations, relative to the notebook's working directory.
# NOTE(review): the validation split is read from the 'test' folder — confirm this is intended.
train_dir, val_dir = 'Dataset/train', 'Dataset/test'

In [3]:
# Open both splits as ImageFolder datasets. No transform is attached here;
# these instances are what the class-count summary cells read from.
train_dataset = datasets.ImageFolder(root=train_dir)
val_dataset = datasets.ImageFolder(root=val_dir)

In [4]:
# Per-class image tally for an ImageFolder-style dataset
def count_images_per_class(dataset):
    """Return a ``{class_name: image_count}`` dict for ``dataset``.

    ``dataset`` must expose ``classes`` (a sequence of class names) and ``imgs``
    (an iterable of ``(path, label_index)`` pairs). Every name in ``classes``
    appears in the result, in the same order, even when its count is zero.
    """
    names = dataset.classes
    tally = dict.fromkeys(names, 0)
    for _, class_index in dataset.imgs:
        tally[names[class_index]] += 1
    return tally

In [5]:
# Tally images per class, separately for the training and validation splits.
train_counts = count_images_per_class(dataset=train_dataset)
val_counts = count_images_per_class(dataset=val_dataset)

In [6]:
# Combine per-class counts across both splits.
# Iterate over the union of class names (training order first, then any
# validation-only names) so a class that exists in just one split is still
# reported; iterating the training classes alone would silently drop it and
# make the per-class totals disagree with total_images.
total_counts = {
    class_name: train_counts.get(class_name, 0) + val_counts.get(class_name, 0)
    for class_name in {**train_counts, **val_counts}
}

# Image totals per split and overall.
total_train_images = sum(train_counts.values())
total_val_images = sum(val_counts.values())
total_images = total_train_images + total_val_images

In [7]:
# Training split summary: number of classes, number of images, per-class counts.
print(
    "Training Dataset:",
    f"Total Classes: {len(train_dataset.classes)}",
    f"Total Images: {total_train_images}",
    f"Class Distribution: {train_counts}",
    sep="\n",
)

Training Dataset:
Total Classes: 8
Total Images: 28386
Class Distribution: {'anger': 2466, 'contempt': 165, 'disgust': 191, 'fear': 652, 'happiness': 7528, 'neutral': 10308, 'sadness': 3514, 'surprise': 3562}


In [8]:
# Validation split summary (leading newline separates it from earlier output).
print(
    "\nValidation Dataset:",
    f"Total Classes: {len(val_dataset.classes)}",
    f"Total Images: {total_val_images}",
    f"Class Distribution: {val_counts}",
    sep="\n",
)


Validation Dataset:
Total Classes: 8
Total Images: 7099
Class Distribution: {'anger': 644, 'contempt': 51, 'disgust': 57, 'fear': 167, 'happiness': 1827, 'neutral': 2597, 'sadness': 856, 'surprise': 900}


In [9]:
# Combined summary over both splits; the class count is taken from the training split.
print(
    "\nOverall Dataset:",
    f"Total Classes: {len(train_dataset.classes)}",
    f"Total Images: {total_images}",
    f"Class Distribution: {total_counts}",
    sep="\n",
)


Overall Dataset:
Total Classes: 8
Total Images: 35485
Class Distribution: {'anger': 3110, 'contempt': 216, 'disgust': 248, 'fear': 819, 'happiness': 9355, 'neutral': 12905, 'sadness': 4370, 'surprise': 4462}


In [10]:
# Data transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [11]:
# Datasets and Dataloaders
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_dataset = datasets.ImageFolder(val_dir, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

In [13]:
# Define the model
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=8)
model = model.to('cpu')

Error while downloading from https://cdn-lfs.huggingface.co/repos/fb/cd/fbcdc88e492959e3ee2515f497fb45cb5f217f9455c204bdf4e0b400c90d0c23/32aa17d6e17b43500f531d5f6dc9bc93e56ed8841b8a75682e1bb295d722405b?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&Expires=1723374817&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMzM3NDgxN319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9mYi9jZC9mYmNkYzg4ZTQ5Mjk1OWUzZWUyNTE1ZjQ5N2ZiNDVjYjVmMjE3Zjk0NTVjMjA0YmRmNGUwYjQwMGM5MGQwYzIzLzMyYWExN2Q2ZTE3YjQzNTAwZjUzMWQ1ZjZkYzliYzkzZTU2ZWQ4ODQxYjhhNzU2ODJlMWJiMjk1ZDcyMjQwNWI%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=AswL%7EjF1gEzs5yPNa105A4B2ltj6QuGkbXizJ6TnK3vLFsPkwirLo8kwVXlecufXVvWZnB1Cyq8Q9-XsuGlNs4Ty2DOwyhGackv8Gr-caL3VIe8FhmHS17sIxrcqw60QZ7uPvfvStY0DSIFYO1w3brlbQoxJ0AQ3MQgPNHWi0Qf468o7VVlu8%7EgJrv6TrCDv0COi1Ld4xSZcsMM-yT5sF5Imn7wEhK-Vf2BO6HCmiHbN3JRLFL

ConnectionError: (MaxRetryError('HTTPSConnectionPool(host=\'cdn-lfs.huggingface.co\', port=443): Max retries exceeded with url: /repos/fb/cd/fbcdc88e492959e3ee2515f497fb45cb5f217f9455c204bdf4e0b400c90d0c23/32aa17d6e17b43500f531d5f6dc9bc93e56ed8841b8a75682e1bb295d722405b?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&Expires=1723374817&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMzM3NDgxN319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9mYi9jZC9mYmNkYzg4ZTQ5Mjk1OWUzZWUyNTE1ZjQ5N2ZiNDVjYjVmMjE3Zjk0NTVjMjA0YmRmNGUwYjQwMGM5MGQwYzIzLzMyYWExN2Q2ZTE3YjQzNTAwZjUzMWQ1ZjZkYzliYzkzZTU2ZWQ4ODQxYjhhNzU2ODJlMWJiMjk1ZDcyMjQwNWI~cmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=AswL~jF1gEzs5yPNa105A4B2ltj6QuGkbXizJ6TnK3vLFsPkwirLo8kwVXlecufXVvWZnB1Cyq8Q9-XsuGlNs4Ty2DOwyhGackv8Gr-caL3VIe8FhmHS17sIxrcqw60QZ7uPvfvStY0DSIFYO1w3brlbQoxJ0AQ3MQgPNHWi0Qf468o7VVlu8~gJrv6TrCDv0COi1Ld4xSZcsMM-yT5sF5Imn7wEhK-Vf2BO6HCmiHbN3JRLFLZIymq2bS1~OkYKUqgOZud5RYJtoJbKbyIoJ9tQG2HWleNXVbO35eX4NMZx-9EAb~xk1nuUTglYlpXz-n~1U2ws4smh14I~bG-~rQ__&Key-Pair-Id=K3ESJI6DHPFC7 (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x000001E7B4671460>: Failed to resolve \'cdn-lfs.huggingface.co\' ([Errno 11001] getaddrinfo failed)"))'), '(Request ID: 7d0fc332-821a-4c18-a579-8f90e6af0f86)')

In [None]:
# Training objective and optimiser: cross-entropy on the model's outputs,
# Adam over every model parameter with a learning rate of 1e-4.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

    Returns ``(mean_loss, accuracy_percent)`` over all samples seen, or
    ``(nan, nan)`` when the loader yields no batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    correct = 0
    total = 0

    # Gradients are only needed for the optimisation pass.
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # Weight each batch loss by its size so a short final batch does
            # not skew the epoch mean. Assumes ``criterion`` returns a
            # per-batch mean (the default reduction).
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100 * correct / total


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=1):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Args:
        model: module mapping an image batch to per-class logits.
        train_loader: iterable of ``(images, labels)`` batches used for optimisation.
        val_loader: iterable of ``(images, labels)`` batches used for evaluation only.
        criterion: loss callable taking ``(outputs, labels)``; expected to
            return the mean loss of the batch.
        optimizer: optimiser already bound to ``model``'s parameters.
        num_epochs: number of full passes over ``train_loader``.

    Prints one summary line per epoch (train/val loss and accuracy) and
    returns ``None``. Metrics are averaged per sample, so loaders need not
    define ``len()``; an empty loader is reported as ``nan`` instead of
    raising ``ZeroDivisionError``.
    """
    # Run on whatever device the model already lives on rather than forcing
    # CPU, so inputs never end up on a different device from the weights.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(num_epochs):
        train_loss, train_accuracy = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer
        )
        val_loss, val_accuracy = _run_epoch(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

In [None]:
# Run a single epoch of training followed by validation.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=1,
)