# Bharatanatyam Mudra Classification

### Import Libraries

In [None]:
import time
import torch
import torch.backends
import torch.nn as nn
import torchvision.models as models

from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

### Define the Image Transform Object

In [None]:
# Preprocessing pipeline applied to every image loaded from disk.
transform = transforms.Compose(
    [
        # Fixed 224x224 spatial size, the input resolution used for the model
        transforms.Resize(size=(224, 224)),
        # PIL image -> float tensor
        transforms.ToTensor(),
        # Per-channel standardisation with the ImageNet mean / std
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

### Load the Dataset

In [None]:
# Each sub-directory of `mudra_data` is treated as one class by ImageFolder.
dataset = datasets.ImageFolder(root='mudra_data', transform=transform)

# 80% train / 10% validation; the test split takes whatever remains so the
# three sizes always sum to len(dataset) despite integer truncation.
train_size = int(0.8 * len(dataset))
valid_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - valid_size

# Seed the split so the same images land in the same subset on every run.
# Without a fixed generator, re-running the notebook reshuffles the subsets,
# and a previously saved checkpoint would be evaluated on images it trained on.
split_generator = torch.Generator().manual_seed(42)
train_dataset, valid_dataset, test_dataset = random_split(
    dataset,
    [train_size, valid_size, test_size],
    generator=split_generator,
)

# Create data loaders (only the training data is shuffled)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)

print(f'Training set: {len(train_dataset)} samples, Validation set: {len(valid_dataset)} samples')

### Define the Model

In [None]:
# Load a ResNet18 initialised with ImageNet weights.
# `pretrained=True` is deprecated in torchvision; the `weights=` enum below
# selects the same ImageNet checkpoint without the deprecation warning.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Replace the final fully connected layer so it emits one logit per class.
# NOTE: num_classes must match the number of class folders in the dataset.
num_classes = 51
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)

# Pick the compute device: Apple MPS first, then CUDA, otherwise CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Training on device {device}")
model.to(device)

# Set up the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

### Define the Training Loop

In [None]:
def train_model(model, train_loader, valid_loader, criterion, optimizer, num_epochs=10,
                checkpoint_path='mudra_model_resnet18.pth'):
    """Train ``model`` and keep the weights with the best validation accuracy.

    Every epoch runs a training pass over ``train_loader`` followed by an
    evaluation pass over ``valid_loader``. Whenever the validation accuracy
    improves, the model's ``state_dict`` is written to ``checkpoint_path``.
    After the last epoch the best checkpoint saved during this call is loaded
    back into ``model`` in place.

    Args:
        model: Network to optimise. Batches are moved to the device of its
            first parameter, so move the model to the target device first.
        train_loader: Iterable of ``(inputs, labels)`` batches for training.
        valid_loader: Iterable of ``(inputs, labels)`` batches for validation.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Number of train + validation epochs to run.
        checkpoint_path: File the best ``state_dict`` is saved to.

    Raises:
        ValueError: If a loader yields no samples during an epoch.
    """
    since = time.time()

    # Use the device the model already lives on instead of a module-level global.
    device = next(model.parameters()).device

    best_acc = 0.0
    # Tracks whether *this* call wrote a checkpoint, so a stale file left by
    # an earlier run is never loaded by mistake.
    checkpoint_saved = False

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Enable dropout / batch-norm statistic updates
            else:
                model.eval()   # Freeze dropout / batch-norm statistics

            running_loss = 0.0
            running_corrects = 0
            samples_seen = 0

            data_loader = train_loader if is_train else valid_loader
            # Build the label outside the f-string: reusing the same quote
            # type inside an f-string is a syntax error before Python 3.12.
            phase_label = "Training" if is_train else "Validation"
            progress = tqdm(data_loader, desc=f"{phase_label} Epoch {epoch + 1}/{num_epochs}")

            for inputs, labels in progress:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    # Clear gradients left over from the previous step
                    optimizer.zero_grad()

                # Only build the autograd graph while training
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Statistics: the loss is a batch mean, so weight it by batch size
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                running_corrects += (preds == labels).sum().item()
                samples_seen += batch_size

            if samples_seen == 0:
                raise ValueError(f"The '{phase}' data loader produced no samples")

            # Average over the samples actually processed; this stays correct
            # even if the loader drops an incomplete final batch.
            epoch_loss = running_loss / samples_seen
            epoch_acc = running_corrects / samples_seen

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # Checkpoint on improvement. The first validation pass always
            # saves, so a best-so-far checkpoint exists even at 0% accuracy.
            if not is_train and (not checkpoint_saved or epoch_acc > best_acc):
                best_acc = epoch_acc
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_saved = True

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Restore the best weights (skipped when no epoch ran, e.g. num_epochs=0)
    if checkpoint_saved:
        model.load_state_dict(
            torch.load(checkpoint_path, map_location=device, weights_only=True))

# Run fine-tuning for 10 epochs with the loaders, loss and optimizer set up above
train_model(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
)