In [1]:
import copy
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split


## Dataset Class

In [2]:
class MagicCardIDDataset(Dataset):
    """Image dataset where each file in ``image_dir`` is labelled by its filename stem.

    Every regular, non-hidden file directly inside ``image_dir`` is one sample. The
    filename without its extension (e.g. ``"abc123"`` for ``"abc123.jpg"``) is the
    card ID, and IDs are mapped to integer class indices by ``self.label_encoder``.

    NOTE(review): unless several files share a stem, every class has exactly one
    image, so a random train/val split places classes in validation that never
    occur in training. Confirm the intended data layout.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # Sorted so sample indices (and therefore any seeded split over them) do not
        # depend on the filesystem's arbitrary directory-listing order.
        self.image_paths = self._list_image_files(image_dir)
        self.transform = transform

        # Encode the unique IDs as numeric labels for classification
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(
            [os.path.splitext(name)[0] for name in self.image_paths]
        )

    @staticmethod
    def _list_image_files(image_dir):
        """Return the sorted names of regular, non-hidden files directly in ``image_dir``.

        Subdirectories (e.g. ``.ipynb_checkpoints``) and dot-files (e.g. ``.DS_Store``)
        are skipped because they cannot be opened as images.
        """
        with os.scandir(image_dir) as entries:
            return sorted(
                entry.name
                for entry in entries
                if entry.is_file() and not entry.name.startswith(".")
            )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the file converted to RGB and passed through ``self.transform``
        when one was given; ``label`` is the encoded integer class index.
        """
        img_name = os.path.join(self.image_dir, self.image_paths[idx])
        # convert() returns a new, fully loaded image, so the file can be closed right away.
        with Image.open(img_name) as img:
            image = img.convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


## Transforms

In [3]:
# Training-time preprocessing: fixed-size input, light random augmentation, then
# tensor conversion and normalisation with the ImageNet channel statistics that
# torchvision's pretrained backbones expect.
# NOTE(review): a horizontal flip mirrors card text and artwork, which photos of a
# real card never show; confirm this augmentation actually helps ID matching.
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(size=IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)


## Splitting the Dataset

In [4]:
# Path to your dataset
image_dir = 'datasets/tcg_magic/training/'
SEED = 42  # shared by the index split and torch's global RNG

# Seed torch's global RNG so DataLoader shuffling, the random augmentations and any
# later weight initialisation are reproducible across kernel restarts.
torch.manual_seed(SEED)

# Create dataset (training view: applies `transform`, including its random augmentations)
full_dataset = MagicCardIDDataset(image_dir=image_dir, transform=transform)

# Validation must be deterministic, so evaluate with the same pipeline minus the
# random augmentations (resize, tensor conversion and normalisation are kept).
RANDOM_AUGMENTATIONS = (transforms.RandomHorizontalFlip, transforms.RandomRotation)
eval_transform = transforms.Compose(
    [step for step in transform.transforms if not isinstance(step, RANDOM_AUGMENTATIONS)]
)

# A shallow copy shares image_paths / labels / label_encoder with full_dataset, so an
# index refers to the same image in both views; only the transform differs.
eval_dataset = copy.copy(full_dataset)
eval_dataset.transform = eval_transform

# NOTE(review): labels are one per filename stem, so unless several files share a stem
# each class has a single image and a random split sends classes to validation that
# never appear in training (val accuracy ~0 by construction). Confirm the data layout.

# Split dataset into train and validation sets (80% train, 20% val)
train_idx, val_idx = train_test_split(list(range(len(full_dataset))), test_size=0.2, random_state=SEED)

train_dataset = torch.utils.data.Subset(full_dataset, train_idx)
val_dataset = torch.utils.data.Subset(eval_dataset, val_idx)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

## Define the Model

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One output logit per distinct card ID known to the dataset's label encoder.
num_classes = len(full_dataset.label_encoder.classes_)

# ResNet18 with torchvision's pretrained weights; its classification head is replaced
# by a fresh linear layer sized to num_classes.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=models.ResNet18_Weights.IMAGENET1K_V1`; switch once the installed version is confirmed.
backbone = models.resnet18(pretrained=True)
backbone.fc = nn.Linear(in_features=backbone.fc.in_features, out_features=num_classes)
model = backbone.to(device)

## Training Setup (Loss, Optimizer)

In [6]:
# Cross-entropy over the card-ID logits; Adam updates every model parameter
# (pretrained backbone and the new head alike).
LEARNING_RATE = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

## Training Loop

In [7]:
# Function for training one epoch
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader`` and return the mean training loss.

    Args:
        model: module to train; switched to train mode.
        dataloader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to return the
            batch *mean* (as ``nn.CrossEntropyLoss`` does by default).
        device: device every batch is moved to.

    Returns:
        Per-sample average loss over the samples actually processed, or ``nan`` if
        the loader yielded no batches.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # The loss is a batch mean; weight it by batch size for a per-sample epoch average.
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    # Divide by what was actually iterated rather than len(dataset): they differ under
    # drop_last=True or a sampler that visits only part of the dataset.
    if samples_seen == 0:
        return float('nan')
    return running_loss / samples_seen

# Function for validation
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: module to evaluate; switched to eval mode.
        dataloader: iterable of ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to return the
            batch *mean* (as ``nn.CrossEntropyLoss`` does by default).
        device: device every batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` over the samples actually processed, where accuracy
        is the fraction whose arg-max logit equals the label. Both are ``nan`` if the
        loader yielded no batches.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    samples_seen = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight the batch-mean loss by batch size for a per-sample average.
            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

            # Predicted class = index of the largest logit along the class dimension.
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()

    # Normalise by samples actually iterated (differs from len(dataset) under
    # drop_last=True or a subset sampler); avoid ZeroDivisionError on an empty loader.
    if samples_seen == 0:
        return float('nan'), float('nan')
    return running_loss / samples_seen, correct / samples_seen

## Train and Validate the Model

In [None]:
num_epochs = 20  # Number of epochs to train for
best_model_path = 'models/tcg_magic/best_model.pth'

# torch.save does not create missing directories, so create the checkpoint folder up front.
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

# Learning rate scheduler: multiply the LR by gamma=0.1 every 5 scheduler steps
# (the loop below steps it once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# NOTE: this cell continues from the current `model` / `optimizer` state; re-run the
# model and optimizer cells first for a fresh training run.

# Start below any reachable accuracy so the first epoch always writes a checkpoint,
# even when validation accuracy stays at 0.0.
best_accuracy = float('-inf')

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_accuracy = validate(model, val_loader, criterion, device)

    # Step the learning rate scheduler once per epoch, after this epoch's optimizer updates
    scheduler.step()

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')

    # Save the best model based on validation accuracy
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save(model.state_dict(), best_model_path)


## Saving the Model

In [9]:
# Save the weights from the last training epoch (these may differ from the
# best-validation checkpoint written during training).
final_model_path = 'models/tcg_magic/model.pth'
# torch.save does not create parent directories, so make sure they exist first.
os.makedirs(os.path.dirname(final_model_path), exist_ok=True)
torch.save(model.state_dict(), final_model_path)
