In [1]:
from model import MultiLoRAViT
from dataloader import AnimeCharacterDataset
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn

# Build the training and held-out test datasets. Both read images from the
# same folder; only the label CSV differs.
dataset, test_dataset = (
    AnimeCharacterDataset(csv_path=csv_file, img_dir='single_characters')
    for csv_file in ('train_data.csv', 'test_data.csv')
)

import json


# Load the per-attribute label encodings; len() of each entry is used as the
# number of classes for that attribute.
with open('encodings.json', 'r') as encodings_file:
    encodings = json.load(encodings_file)

# Adapter name -> number of classes. Key order is significant: the training
# loop selects label columns by each adapter's position in this dict.
adapter_config = {
    attribute: len(encodings[attribute])
    for attribute in (
        'Age',
        'Gender',
        'Ethnicity',
        'Hair Style',
        'Hair Color',
        'Hair Length',
        'Eye Color',
        'Body Type',
        'Dress',
    )
}


# Seed torch's global RNG so that anything drawing from it (e.g. DataLoader
# shuffling) is repeatable after a kernel restart.
SEED = 42
torch.manual_seed(SEED)

model = MultiLoRAViT(adapter_config, r=4)
model.switch_adapter('Age')  # Switch to Age adapter

# Hold out 20% of the training data for validation. A dedicated, seeded
# generator keeps the split identical across runs, so validation accuracies
# (and the checkpoints selected from them) stay comparable.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128)

# Training configuration
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


def _unique_params(params):
    """Return `params` in first-seen order with repeated tensors removed.

    Identity (`id`) is used rather than equality because tensors do not
    support plain `==` comparison for membership tests.
    """
    return list({id(p): p for p in params}.values())


# One Adam optimizer per adapter, covering that adapter's head plus every
# parameter that currently requires grad. The head's parameters can appear in
# both lists (presumably `model.heads` is a registered submodule -- confirm in
# model.py); passing the same tensor twice makes PyTorch warn about duplicate
# parameters and may apply the update to it twice per step, so de-duplicate.
optimizers = {
    name: optim.Adam(
        _unique_params(
            list(model.heads[name].parameters())
            + [p for p in model.parameters() if p.requires_grad]
        ),
        lr=1e-3,
    )
    for name in adapter_config
}
criterion = nn.CrossEntropyLoss()

  return disable_fn(*args, **kwargs)


In [2]:
# Best validation accuracy (in percent) seen so far for each adapter; a
# checkpoint is written whenever an adapter beats its own previous best.
best_val_accuracies = {name: 0.0 for name in adapter_config}

for epoch in range(num_epochs):
    model.train()
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)

        # Train each adapter separately on the same batch of images.
        for adapter_idx, adapter_name in enumerate(adapter_config):
            model.switch_adapter(adapter_name)
            optimizer = optimizers[adapter_name]

            # Column `adapter_idx` of the label tensor is this adapter's target
            # (assumes the dataset emits label columns in adapter_config
            # order -- TODO confirm against dataloader.py).
            label = labels[:, adapter_idx].long().to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, label)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f'Epoch: {epoch}, Adapter: {adapter_name}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

    # Validation
    model.eval()
    with torch.no_grad():
        val_losses = {name: 0.0 for name in adapter_config}
        val_correct = {name: 0 for name in adapter_config}
        val_total = 0

        for images, labels in val_loader:
            images = images.to(device)
            batch_size = images.size(0)
            val_total += batch_size

            for adapter_idx, adapter_name in enumerate(adapter_config):
                model.switch_adapter(adapter_name)
                label = labels[:, adapter_idx].long().to(device)

                outputs = model(images)
                loss = criterion(outputs, label)
                # The criterion (nn.CrossEntropyLoss, default reduction) returns
                # a per-batch mean. Weight it by the batch size so the epoch
                # average is per sample and a smaller final batch is not
                # over-weighted.
                val_losses[adapter_name] += loss.item() * batch_size

                _, predicted = outputs.max(1)
                val_correct[adapter_name] += predicted.eq(label).sum().item()

        # Print validation results
        print(f'\nEpoch {epoch} Validation Results:')
        for adapter_name in adapter_config:
            # Guard against an empty validation split instead of dividing by zero.
            acc = 100. * val_correct[adapter_name] / val_total if val_total else 0.0
            avg_loss = val_losses[adapter_name] / val_total if val_total else float('nan')
            print(f'{adapter_name}: Accuracy: {acc:.2f}%, Avg Loss: {avg_loss:.4f}')

            # Save the best model for each adapter based on accuracy
            if acc > best_val_accuracies[adapter_name]:
                best_val_accuracies[adapter_name] = acc
                print(f'Saving model for adapter: {adapter_name}')
                model.save_model(epoch=epoch, exp='checkpoints', adapter_name=adapter_name)

        print()

Epoch: 0, Adapter: Age, Batch: 0, Loss: 2.4940
Epoch: 0, Adapter: Gender, Batch: 0, Loss: 1.8274
Epoch: 0, Adapter: Ethnicity, Batch: 0, Loss: 2.3091
Epoch: 0, Adapter: Hair Style, Batch: 0, Loss: 3.7287
Epoch: 0, Adapter: Hair Color, Batch: 0, Loss: 3.5526
Epoch: 0, Adapter: Hair Length, Batch: 0, Loss: 1.5644
Epoch: 0, Adapter: Eye Color, Batch: 0, Loss: 3.5098
Epoch: 0, Adapter: Body Type, Batch: 0, Loss: 3.4425
Epoch: 0, Adapter: Dress, Batch: 0, Loss: 4.5570
Epoch: 0, Adapter: Age, Batch: 10, Loss: 1.7198
Epoch: 0, Adapter: Gender, Batch: 10, Loss: 0.9180
Epoch: 0, Adapter: Ethnicity, Batch: 10, Loss: 1.2269
Epoch: 0, Adapter: Hair Style, Batch: 10, Loss: 2.0251
Epoch: 0, Adapter: Hair Color, Batch: 10, Loss: 2.0979
Epoch: 0, Adapter: Hair Length, Batch: 10, Loss: 0.7122
Epoch: 0, Adapter: Eye Color, Batch: 10, Loss: 2.2635
Epoch: 0, Adapter: Body Type, Batch: 10, Loss: 1.9597
Epoch: 0, Adapter: Dress, Batch: 10, Loss: 3.0974
Epoch: 0, Adapter: Age, Batch: 20, Loss: 1.6025
Epoch: 

In [4]:
from model import MultiLoRAViT
from dataloader import AnimeCharacterDataset
import torch
from torch.utils.data import DataLoader
import json

# Class counts per attribute are derived from the stored label encodings.
with open('encodings.json', 'r') as encodings_file:
    encodings = json.load(encodings_file)

# Adapter name -> number of classes, in the order used to index label columns.
adapter_config = {
    attribute: len(encodings[attribute])
    for attribute in (
        'Age',
        'Gender',
        'Ethnicity',
        'Hair Style',
        'Hair Color',
        'Hair Length',
        'Eye Color',
        'Body Type',
        'Dress',
    )
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiLoRAViT(adapter_config, r=4)

# Restore the saved checkpoints; the returned values are paired with adapter
# names by position when reported below.
epochs = model.load_model(exp='checkpoints')
model = model.to(device)
print("Loaded checkpoints from epochs:", dict(zip(adapter_config.keys(), epochs)))

test_dataset = AnimeCharacterDataset(csv_path='test_data.csv', img_dir='single_characters')
test_loader = DataLoader(test_dataset, batch_size=32)

model.eval()
# Running tallies of correct predictions and samples seen, per adapter.
results = {name: dict(correct=0, total=0) for name in adapter_config}

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)

        for adapter_idx, adapter_name in enumerate(adapter_config):
            model.switch_adapter(adapter_name)
            # Target for this adapter is the label column at its position.
            label = labels[:, adapter_idx].long().to(device)

            # Index of the highest-scoring class for every sample in the batch.
            predicted = model(images).max(1).indices

            tally = results[adapter_name]
            tally['correct'] += (predicted == label).sum().item()
            tally['total'] += labels.shape[0]

print("\nTest Results:")
for adapter_name, tally in results.items():
    accuracy = 100. * tally['correct'] / tally['total']
    print(f"{adapter_name}: Test Accuracy = {accuracy:.2f}%")


Loaded checkpoints from epochs: {'Age': 4, 'Gender': 6, 'Ethnicity': 9, 'Hair Style': 5, 'Hair Color': 8, 'Hair Length': 5, 'Eye Color': 4, 'Body Type': 1, 'Dress': 6}

Test Results:
Age: Test Accuracy = 56.50%
Gender: Test Accuracy = 95.53%
Ethnicity: Test Accuracy = 57.32%
Hair Style: Test Accuracy = 50.41%
Hair Color: Test Accuracy = 73.17%
Hair Length: Test Accuracy = 72.76%
Eye Color: Test Accuracy = 51.63%
Body Type: Test Accuracy = 43.50%
Dress: Test Accuracy = 37.80%


### Testing: single-image attribute prediction

In [11]:
from model import MultiLoRAViT
import torch
from PIL import Image
from torchvision import transforms
import json

# Label encodings: used both to size the adapters and, at prediction time, to
# map predicted class indices back to label names.
with open('encodings.json', 'r') as encodings_file:
    encodings = json.load(encodings_file)

# Adapter name -> number of classes.
adapter_config = {
    attribute: len(encodings[attribute])
    for attribute in (
        'Age',
        'Gender',
        'Ethnicity',
        'Hair Style',
        'Hair Color',
        'Hair Length',
        'Eye Color',
        'Body Type',
        'Dress',
    )
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiLoRAViT(adapter_config, r=4)

# Restore the saved checkpoints, then move the model to the target device.
epochs = model.load_model(exp='checkpoints')
model = model.to(device)
print("Loaded checkpoints from epochs:", dict(zip(adapter_config.keys(), epochs)))

# Preprocessing applied to every input image: resize to 224x224, convert to a
# tensor, then normalise each channel with the fixed mean/std below.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def predict_attributes(image_path):
    """Predict every encoded attribute for a single image.

    The image is loaded as RGB, run through the module-level `transform`, and
    passed through the model once per attribute in `encodings`, switching to
    the matching adapter each time. The arg-max class index is mapped back to
    its label via the inverted encoding.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        dict mapping attribute name -> predicted label, or None if anything
        failed (the error is printed rather than raised, so one bad file does
        not abort a batch run).
    """
    try:
        # Use a context manager so the file handle is released even for
        # multi-frame formats or when decoding fails; convert() returns an
        # independent in-memory image that stays valid after the file closes.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        # unsqueeze(0) adds the leading batch dimension (batch of one).
        image_tensor = transform(image).unsqueeze(0).to(device)

        model.eval()
        predictions = {}

        with torch.no_grad():
            for adapter_name in encodings.keys():
                model.switch_adapter(adapter_name)
                output = model(image_tensor)
                pred_idx = output.argmax(1).item()

                # encodings maps label -> index; invert it to look up the label.
                reverse_encoding = {v: k for k, v in encodings[adapter_name].items()}
                predictions[adapter_name] = reverse_encoding[pred_idx]

        return predictions

    except Exception as e:
        # Deliberate best-effort: report and signal failure with None.
        print(f"Error processing image: {e}")
        return None

Loaded checkpoints from epochs: {'Age': 4, 'Gender': 6, 'Ethnicity': 9, 'Hair Style': 5, 'Hair Color': 8, 'Hair Length': 5, 'Eye Color': 4, 'Body Type': 1, 'Dress': 6}


In [13]:
import os

# Predict attributes for every file in test_images/. Entries are sorted
# because os.listdir returns them in arbitrary order, which would make the
# printed report differ between runs and machines.
for image in sorted(os.listdir('test_images')):
    predictions = predict_attributes(os.path.join('test_images', image))
    print(f"Image: {image}, Predicted attributes: {predictions}")

Image: danbooru_431031_d94bbf167e4067caab3c36804f079f4d.jpg, Predicted attributes: {'Age': 'Young', 'Body Type': 'Unknown', 'Dress': 'Collared Shirt', 'Ethnicity': 'Japanese', 'Eye Color': 'Grey', 'Gender': 'Male', 'Hair Color': 'Black', 'Hair Length': 'Short', 'Hair Style': 'Short'}
Image: danbooru_469064_98d34d4e476870f70b899274dc8e6547.jpg, Predicted attributes: {'Age': 'Early', 'Body Type': 'Normal', 'Dress': 'School Uniform', 'Ethnicity': 'Japanese', 'Eye Color': 'Brown', 'Gender': 'Female', 'Hair Color': 'Black', 'Hair Length': 'Long', 'Hair Style': 'Long'}
Image: danbooru_469135_0355266341d00f6fed83a6081942ecad.jpg, Predicted attributes: {'Age': 'Young', 'Body Type': 'Slim', 'Dress': 'Maid', 'Ethnicity': 'Asian', 'Eye Color': 'Green', 'Gender': 'Male', 'Hair Color': 'Black', 'Hair Length': 'Long', 'Hair Style': 'Straight'}
Image: danbooru_469090_9584201f870356e0bc04a44693408a1a.jpg, Predicted attributes: {'Age': 'Teen', 'Body Type': 'Slim', 'Dress': 'Serafuku', 'Ethnicity': 'Jap