In [1]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

2024-07-22 13:20:12.258634: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-22 13:20:12.258696: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-22 13:20:12.260179: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
class FER2013Dataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']
        self.images, self.labels = self._load_dataset()

    def _load_dataset(self):
        images = []
        labels = []
        for class_idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(self.root_dir, class_name)
            for img_name in os.listdir(class_dir):
                img_path = os.path.join(class_dir, img_name)
                images.append(img_path)
                labels.append(class_idx)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [3]:

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # CLIP expects 224x224
    transforms.Grayscale(num_output_channels=3),  # Convert to 3-channel grayscale
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dataset = FER2013Dataset('/kaggle/input/fer2013/train',
                               transform=transform)
test_dataset = FER2013Dataset('/kaggle/input/fer2013/test',
                              transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,)

In [4]:
class FERModel(nn.Module):
    def __init__(self, num_classes, device='cuda'):
        super(FERModel, self).__init__()
        self.clip_model = (CLIPModel.from_pretrained("openai/clip-vit-base-patch32").vision_model).to(device)
        self.classifier = (nn.Linear(768, num_classes)).to(device) # CLIP's output dimension is 768
        # Freeze CLIP parameters
        for param in self.clip_model.parameters():
            param.requires_grad = False

    def forward(self, x):
        with torch.no_grad():
            features = self.clip_model(x).last_hidden_state[:, 0, :]  # Use CLS token
        output = self.classifier(features)
        return output


In [5]:
def train(model, train_loader, criterion, optimizer, device, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}")
    for batch_idx, (images, labels) in progress_bar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        progress_bar.set_postfix({
            'Loss': f"{running_loss/(batch_idx+1):.3f}",
            'Acc': f"{100.*correct/total:.2f}%"
        })    
    return running_loss / len(train_loader), 100. * correct / total

In [6]:
def evaluate(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            images, labels = batch
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    accuracy = 100. * correct / total
    average_loss = total_loss / len(test_loader)
    return average_loss, accuracy

In [7]:
model = FERModel(num_classes=7)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

  return self.fget.__get__(instance, owner)()


In [8]:

train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
best_accuracy = 0.0

num_epochs = 20
for epoch in range(num_epochs):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device, epoch)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    
    # Store metrics
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    test_losses.append(test_loss)
    test_accuracies.append(test_acc)
    
    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
    print()


Epoch 1: 100%|██████████| 898/898 [02:23<00:00,  6.24it/s, Loss=1.124, Acc=59.40%]
Evaluating: 100%|██████████| 225/225 [00:35<00:00,  6.26it/s]


Epoch 1/20:
Train Loss: 1.1243, Train Acc: 59.40%
Test Loss: 1.0103, Test Acc: 61.86%



Epoch 2: 100%|██████████| 898/898 [02:22<00:00,  6.32it/s, Loss=0.955, Acc=64.83%]
Evaluating: 100%|██████████| 225/225 [00:35<00:00,  6.35it/s]


Epoch 2/20:
Train Loss: 0.9550, Train Acc: 64.83%
Test Loss: 0.9688, Test Acc: 62.76%



Epoch 3: 100%|██████████| 898/898 [02:21<00:00,  6.34it/s, Loss=0.921, Acc=65.53%]
Evaluating: 100%|██████████| 225/225 [00:35<00:00,  6.39it/s]


Epoch 3/20:
Train Loss: 0.9211, Train Acc: 65.53%
Test Loss: 0.9463, Test Acc: 64.07%



Epoch 4: 100%|██████████| 898/898 [02:22<00:00,  6.28it/s, Loss=0.902, Acc=66.25%]
Evaluating: 100%|██████████| 225/225 [00:35<00:00,  6.34it/s]


Epoch 4/20:
Train Loss: 0.9018, Train Acc: 66.25%
Test Loss: 0.9337, Test Acc: 64.35%



Epoch 5: 100%|██████████| 898/898 [02:21<00:00,  6.36it/s, Loss=0.889, Acc=66.76%]
Evaluating: 100%|██████████| 225/225 [00:35<00:00,  6.38it/s]


Epoch 5/20:
Train Loss: 0.8888, Train Acc: 66.76%
Test Loss: 0.9352, Test Acc: 64.82%



Epoch 6: 100%|██████████| 898/898 [02:21<00:00,  6.36it/s, Loss=0.880, Acc=67.28%]
Evaluating: 100%|██████████| 225/225 [00:34<00:00,  6.43it/s]


Epoch 6/20:
Train Loss: 0.8796, Train Acc: 67.28%
Test Loss: 0.9266, Test Acc: 64.85%



Epoch 7: 100%|██████████| 898/898 [02:21<00:00,  6.35it/s, Loss=0.874, Acc=67.52%]
Evaluating: 100%|██████████| 225/225 [00:35<00:00,  6.30it/s]


Epoch 7/20:
Train Loss: 0.8738, Train Acc: 67.52%
Test Loss: 0.9184, Test Acc: 65.24%



Epoch 8: 100%|██████████| 898/898 [02:22<00:00,  6.32it/s, Loss=0.867, Acc=67.63%]
Evaluating: 100%|██████████| 225/225 [00:35<00:00,  6.40it/s]


Epoch 8/20:
Train Loss: 0.8671, Train Acc: 67.63%
Test Loss: 0.9168, Test Acc: 64.68%



Epoch 9: 100%|██████████| 898/898 [02:24<00:00,  6.22it/s, Loss=0.862, Acc=67.88%]
Evaluating: 100%|██████████| 225/225 [00:36<00:00,  6.11it/s]


Epoch 9/20:
Train Loss: 0.8621, Train Acc: 67.88%
Test Loss: 0.9143, Test Acc: 64.96%



Epoch 10: 100%|██████████| 898/898 [02:23<00:00,  6.27it/s, Loss=0.857, Acc=68.24%]
Evaluating: 100%|██████████| 225/225 [00:35<00:00,  6.38it/s]


Epoch 10/20:
Train Loss: 0.8569, Train Acc: 68.24%
Test Loss: 0.9145, Test Acc: 65.42%



Epoch 11: 100%|██████████| 898/898 [02:22<00:00,  6.29it/s, Loss=0.852, Acc=68.16%]
Evaluating: 100%|██████████| 225/225 [00:35<00:00,  6.35it/s]


Epoch 11/20:
Train Loss: 0.8524, Train Acc: 68.16%
Test Loss: 0.9126, Test Acc: 65.45%



Epoch 12: 100%|██████████| 898/898 [02:23<00:00,  6.27it/s, Loss=0.849, Acc=68.39%]
Evaluating: 100%|██████████| 225/225 [00:35<00:00,  6.43it/s]


Epoch 12/20:
Train Loss: 0.8490, Train Acc: 68.39%
Test Loss: 0.9070, Test Acc: 65.66%



Epoch 13: 100%|██████████| 898/898 [02:22<00:00,  6.32it/s, Loss=0.846, Acc=68.68%]
Evaluating: 100%|██████████| 225/225 [00:35<00:00,  6.27it/s]


Epoch 13/20:
Train Loss: 0.8460, Train Acc: 68.68%
Test Loss: 0.9076, Test Acc: 65.69%



Epoch 14: 100%|██████████| 898/898 [02:22<00:00,  6.28it/s, Loss=0.843, Acc=68.73%]
Evaluating: 100%|██████████| 225/225 [00:34<00:00,  6.44it/s]


Epoch 14/20:
Train Loss: 0.8428, Train Acc: 68.73%
Test Loss: 0.9085, Test Acc: 65.14%



Epoch 15: 100%|██████████| 898/898 [02:25<00:00,  6.18it/s, Loss=0.839, Acc=68.78%]
Evaluating: 100%|██████████| 225/225 [00:34<00:00,  6.48it/s]


Epoch 15/20:
Train Loss: 0.8394, Train Acc: 68.78%
Test Loss: 0.9114, Test Acc: 65.63%



Epoch 16: 100%|██████████| 898/898 [02:20<00:00,  6.37it/s, Loss=0.838, Acc=68.91%]
Evaluating: 100%|██████████| 225/225 [00:34<00:00,  6.49it/s]


Epoch 16/20:
Train Loss: 0.8377, Train Acc: 68.91%
Test Loss: 0.9071, Test Acc: 65.10%



Epoch 17: 100%|██████████| 898/898 [02:22<00:00,  6.32it/s, Loss=0.834, Acc=68.95%]
Evaluating: 100%|██████████| 225/225 [00:35<00:00,  6.34it/s]


Epoch 17/20:
Train Loss: 0.8344, Train Acc: 68.95%
Test Loss: 0.9187, Test Acc: 65.10%



Epoch 18: 100%|██████████| 898/898 [02:22<00:00,  6.32it/s, Loss=0.833, Acc=68.93%]
Evaluating: 100%|██████████| 225/225 [00:36<00:00,  6.23it/s]


Epoch 18/20:
Train Loss: 0.8333, Train Acc: 68.93%
Test Loss: 0.9021, Test Acc: 65.90%



Epoch 19: 100%|██████████| 898/898 [02:24<00:00,  6.21it/s, Loss=0.831, Acc=68.98%]
Evaluating: 100%|██████████| 225/225 [00:34<00:00,  6.51it/s]


Epoch 19/20:
Train Loss: 0.8307, Train Acc: 68.98%
Test Loss: 0.9061, Test Acc: 65.55%



Epoch 20: 100%|██████████| 898/898 [02:20<00:00,  6.40it/s, Loss=0.829, Acc=69.37%]
Evaluating: 100%|██████████| 225/225 [00:34<00:00,  6.50it/s]

Epoch 20/20:
Train Loss: 0.8291, Train Acc: 69.37%
Test Loss: 0.9012, Test Acc: 65.83%






In [10]:
torch.save(model, '/kaggle/working/model.pth')

In [None]:
model = torch.load('model.pth')
