In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import zipfile
import json
import base64
import io
from PIL import Image
import numpy as np
from tqdm import tqdm
from mnist_skeptic_v9 import skeptic_v9
import os
import torch.nn.functional as F

In [3]:
class CompositeDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def base64_to_image(self, base64_string):
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data)).convert('L')
        image_array = np.array(image, dtype=np.float32) / 255.0
        return torch.from_numpy(image_array).unsqueeze(0)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image = self.base64_to_image(item['composite'])
        label = torch.tensor(int(item['true_digit']), dtype=torch.long)
        return image, label

class SelectionDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = [item for sublist in data for item in sublist]  # Flatten the list of lists

    def base64_to_image(self, base64_string):
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data)).convert('L')
        image_array = np.array(image, dtype=np.float32) / 255.0
        return torch.from_numpy(image_array).unsqueeze(0)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image = self.base64_to_image(item['selected_image'])
        label = torch.tensor(int(item['true_digit']), dtype=torch.long)
        return image, label

def load_data(test_file_path, train_file_path):
    with open(test_file_path, 'r') as f:
        test_data = json.load(f)
    
    with open(train_file_path, 'r') as f:
        train_data = json.load(f)
    
    return test_data, train_data

In [4]:
class EnsembleModel(nn.Module):
    def __init__(self, model_paths):
        super(EnsembleModel, self).__init__()
        self.models = nn.ModuleList([skeptic_v9() for _ in range(len(model_paths))])
        for model, path in zip(self.models, model_paths):
            model.load_state_dict(torch.load(path))
            model.eval()

    def forward(self, x):
        outputs = [model(x) for model in self.models]
        return torch.stack(outputs).mean(dim=0)

def create_ensemble(model_dir='best_boi_models'):
    model_paths = [os.path.join(model_dir, f) for f in os.listdir(model_dir) if f.endswith('.pth')]
    return EnsembleModel(model_paths)

In [5]:
class WeightedCrossEntropyLoss(nn.Module):
    def __init__(self, weight_8=0.5):
        super().__init__()
        self.weight_8 = weight_8

    def forward(self, outputs, targets):
        ce_loss = F.cross_entropy(outputs, targets, reduction='none')
        weights = torch.ones_like(targets, dtype=torch.float)
        weights[targets != 8] = 1 / (1 - self.weight_8)
        weights[targets == 8] = self.weight_8
        return (ce_loss * weights).mean()

In [6]:
def fine_tune_model(model, train_loader, val_loader, num_epochs=5, learning_rate=0.0001, weight_8=0.5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = WeightedCrossEntropyLoss(weight_8=weight_8)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
    
    best_val_accuracy = 0
    best_model_state = None
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        class_correct = [0] * 10
        class_total = [0] * 10
        
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            for i in range(10):
                class_correct[i] += ((predicted == i) & (labels == i)).sum().item()
                class_total[i] += (labels == i).sum().item()
            
            accuracy = 100 * correct / total
            progress_bar.set_postfix({'Loss': running_loss/len(progress_bar), 'Accuracy': f'{accuracy:.2f}%'})
        
        # Print class-wise accuracies
        for i in range(10):
            class_acc = 100 * class_correct[i] / class_total[i] if class_total[i] > 0 else 0
            print(f'Accuracy of {i}: {class_acc:.2f}%')
        
        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        
        val_accuracy = 100 * val_correct / val_total
        print(f'Epoch {epoch+1} Validation Accuracy: {val_accuracy:.2f}%')
        
        scheduler.step(val_accuracy)
        
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_model_state = model.state_dict()
    
    # Load the best model state
    model.load_state_dict(best_model_state)
    return model

In [33]:
def main():
    # Load the pre-trained ensemble model
    ensemble_model = create_ensemble()
    
    # Load data
    test_data, train_data = load_data('training_data/test_set/participant_0.json', 'training_data/training_set/participant_0.json')
    
    # Create datasets
    train_composite_dataset = CompositeDataset(test_data)  # Using test_data for composite images
    train_selection_dataset = SelectionDataset(train_data)
    val_composite_dataset = CompositeDataset(test_data)
    
    # Create data loaders
    train_composite_loader = DataLoader(train_composite_dataset, batch_size=64, shuffle=True)
    train_selection_loader = DataLoader(train_selection_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_composite_dataset, batch_size=64, shuffle=False)
    
    # Fine-tune the model on composite images
    print("Fine-tuning on composite images...")
    fine_tuned_model = fine_tune_model(ensemble_model, train_composite_loader, val_loader, weight_8=0.5)
    
    # Further fine-tune on selection images
    print("Fine-tuning on selection images...")
    final_model = fine_tune_model(fine_tuned_model, train_selection_loader, val_loader, weight_8=0.5)
    
    # Save the fine-tuned model
    torch.save(final_model.state_dict(), 'ensemble_finetuned_megaload.pth')
    print("Fine-tuned model saved as 'ensemble_finetuned_megaload.pth'")

if __name__ == "__main__":
    main()

  model.load_state_dict(torch.load(path))


Fine-tuning on composite images...


Epoch 1/5: 100%|██████████| 6/6 [00:00<00:00,  7.14it/s, Loss=4.96, Accuracy=21.43%] 


Accuracy of 0: 42.86%
Accuracy of 1: 37.14%
Accuracy of 2: 14.29%
Accuracy of 3: 8.57%
Accuracy of 4: 11.43%
Accuracy of 5: 28.57%
Accuracy of 6: 22.86%
Accuracy of 7: 17.14%
Accuracy of 8: 31.43%
Accuracy of 9: 0.00%
Epoch 1 Validation Accuracy: 10.00%


Epoch 2/5: 100%|██████████| 6/6 [00:00<00:00, 19.56it/s, Loss=4.87, Accuracy=20.00%]


Accuracy of 0: 31.43%
Accuracy of 1: 28.57%
Accuracy of 2: 17.14%
Accuracy of 3: 11.43%
Accuracy of 4: 11.43%
Accuracy of 5: 37.14%
Accuracy of 6: 20.00%
Accuracy of 7: 17.14%
Accuracy of 8: 25.71%
Accuracy of 9: 0.00%
Epoch 2 Validation Accuracy: 10.00%


Epoch 3/5: 100%|██████████| 6/6 [00:00<00:00, 20.23it/s, Loss=4.73, Accuracy=21.71%]


Accuracy of 0: 40.00%
Accuracy of 1: 37.14%
Accuracy of 2: 17.14%
Accuracy of 3: 11.43%
Accuracy of 4: 8.57%
Accuracy of 5: 28.57%
Accuracy of 6: 25.71%
Accuracy of 7: 20.00%
Accuracy of 8: 28.57%
Accuracy of 9: 0.00%
Epoch 3 Validation Accuracy: 9.71%


Epoch 4/5: 100%|██████████| 6/6 [00:00<00:00, 19.10it/s, Loss=4.65, Accuracy=21.71%]


Accuracy of 0: 42.86%
Accuracy of 1: 31.43%
Accuracy of 2: 14.29%
Accuracy of 3: 8.57%
Accuracy of 4: 14.29%
Accuracy of 5: 37.14%
Accuracy of 6: 25.71%
Accuracy of 7: 17.14%
Accuracy of 8: 25.71%
Accuracy of 9: 0.00%
Epoch 4 Validation Accuracy: 9.14%


Epoch 5/5: 100%|██████████| 6/6 [00:00<00:00, 17.96it/s, Loss=4.58, Accuracy=20.00%]


Accuracy of 0: 37.14%
Accuracy of 1: 31.43%
Accuracy of 2: 20.00%
Accuracy of 3: 11.43%
Accuracy of 4: 11.43%
Accuracy of 5: 28.57%
Accuracy of 6: 25.71%
Accuracy of 7: 17.14%
Accuracy of 8: 17.14%
Accuracy of 9: 0.00%
Epoch 5 Validation Accuracy: 11.71%
Fine-tuning on selection images...


Epoch 1/5: 100%|██████████| 2051/2051 [01:39<00:00, 20.66it/s, Loss=4.27, Accuracy=10.30%]


Accuracy of 0: 16.88%
Accuracy of 1: 14.96%
Accuracy of 2: 6.77%
Accuracy of 3: 9.85%
Accuracy of 4: 24.50%
Accuracy of 5: 11.72%
Accuracy of 6: 3.98%
Accuracy of 7: 10.10%
Accuracy of 8: 0.78%
Accuracy of 9: 3.44%
Epoch 1 Validation Accuracy: 14.00%


Epoch 2/5: 100%|██████████| 2051/2051 [01:38<00:00, 20.72it/s, Loss=4.19, Accuracy=10.75%]


Accuracy of 0: 17.21%
Accuracy of 1: 19.18%
Accuracy of 2: 6.19%
Accuracy of 3: 12.15%
Accuracy of 4: 17.61%
Accuracy of 5: 9.73%
Accuracy of 6: 8.00%
Accuracy of 7: 11.23%
Accuracy of 8: 0.00%
Accuracy of 9: 6.18%
Epoch 2 Validation Accuracy: 15.71%


Epoch 3/5: 100%|██████████| 2051/2051 [01:35<00:00, 21.38it/s, Loss=4.18, Accuracy=11.34%]


Accuracy of 0: 17.98%
Accuracy of 1: 22.40%
Accuracy of 2: 6.02%
Accuracy of 3: 9.68%
Accuracy of 4: 18.04%
Accuracy of 5: 10.45%
Accuracy of 6: 9.74%
Accuracy of 7: 10.93%
Accuracy of 8: 0.00%
Accuracy of 9: 8.19%
Epoch 3 Validation Accuracy: 16.00%


Epoch 4/5: 100%|██████████| 2051/2051 [01:28<00:00, 23.07it/s, Loss=4.17, Accuracy=12.01%]


Accuracy of 0: 18.58%
Accuracy of 1: 25.56%
Accuracy of 2: 6.90%
Accuracy of 3: 9.14%
Accuracy of 4: 18.58%
Accuracy of 5: 10.51%
Accuracy of 6: 11.73%
Accuracy of 7: 10.43%
Accuracy of 8: 0.00%
Accuracy of 9: 8.69%
Epoch 4 Validation Accuracy: 18.00%


Epoch 5/5: 100%|██████████| 2051/2051 [01:28<00:00, 23.30it/s, Loss=4.17, Accuracy=12.46%]


Accuracy of 0: 18.00%
Accuracy of 1: 27.71%
Accuracy of 2: 7.83%
Accuracy of 3: 7.44%
Accuracy of 4: 19.40%
Accuracy of 5: 12.14%
Accuracy of 6: 13.20%
Accuracy of 7: 11.03%
Accuracy of 8: 0.00%
Accuracy of 9: 7.81%
Epoch 5 Validation Accuracy: 18.00%
Fine-tuned model saved as 'ensemble_finetuned_megaload.pth'


In [7]:
class WeightedCrossEntropyLoss(nn.Module):
    def __init__(self, class_weights):
        super().__init__()
        self.class_weights = class_weights

    def forward(self, outputs, targets):
        return F.cross_entropy(outputs, targets, weight=self.class_weights)

class CompositeDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def base64_to_image(self, base64_string):
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data)).convert('L')
        image_array = np.array(image, dtype=np.float32) / 255.0
        return torch.from_numpy(image_array).unsqueeze(0)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image = self.base64_to_image(item['composite'])
        label = torch.tensor(int(item['true_digit']), dtype=torch.long)
        return image, label

class SelectionDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = [item for sublist in data for item in sublist]

    def base64_to_image(self, base64_string):
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data)).convert('L')
        image_array = np.array(image, dtype=np.float32) / 255.0
        return torch.from_numpy(image_array).unsqueeze(0)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image = self.base64_to_image(item['selected_image'])
        label = torch.tensor(int(item['true_digit']), dtype=torch.long)
        return image, label

def load_data(train_folder, test_folder, num_participants=5):
    train_data = []
    for file in os.listdir(train_folder)[:num_participants]:
        with open(os.path.join(train_folder, file), 'r') as f:
            train_data.extend(json.load(f))
    
    test_data = []
    for file in os.listdir(test_folder)[:num_participants]:  # Use the same number of participants for test data
        with open(os.path.join(test_folder, file), 'r') as f:
            test_data.extend(json.load(f))
    
    return train_data, test_data

def calculate_class_weights(dataset):
    class_counts = torch.zeros(10)
    for _, label in dataset:
        class_counts[label] += 1
    class_weights = 1.0 / class_counts
    class_weights = class_weights / class_weights.sum()
    return class_weights

def fine_tune_model(model, train_loader, val_loader, num_epochs=10, learning_rate=0.0001, class_weights=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = WeightedCrossEntropyLoss(class_weights.to(device) if class_weights is not None else None)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
    
    best_val_accuracy = 0
    best_model_state = None
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            accuracy = 100 * correct / total
            progress_bar.set_postfix({'Loss': running_loss/len(progress_bar), 'Accuracy': f'{accuracy:.2f}%'})
        
        # Validation
        val_accuracy = evaluate_model(model, val_loader, device)
        print(f'Epoch {epoch+1} Validation Accuracy: {val_accuracy:.2f}%')
        
        scheduler.step(val_accuracy)
        
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_model_state = model.state_dict()
    
    # Load the best model state
    model.load_state_dict(best_model_state)
    return model

def evaluate_model(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    return accuracy

def main():
    # Load the pre-trained ensemble model
    ensemble_model = create_ensemble()
    
    # Load data
    train_data, test_data = load_data('training_data/training_set', 'training_data/test_set', num_participants=5)
    
    # Create datasets
    train_composite_dataset = CompositeDataset(test_data)
    train_selection_dataset = SelectionDataset(train_data)
    val_composite_dataset = CompositeDataset(test_data)
    
    # Calculate class weights
    class_weights = calculate_class_weights(ConcatDataset([train_composite_dataset, train_selection_dataset]))
    
    # Create data loaders
    train_composite_loader = DataLoader(train_composite_dataset, batch_size=64, shuffle=True)
    train_selection_loader = DataLoader(train_selection_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_composite_dataset, batch_size=64, shuffle=False)
    
    # Fine-tune the model on composite images
    print("Fine-tuning on composite images...")
    fine_tuned_model = fine_tune_model(ensemble_model, train_composite_loader, val_loader, class_weights=class_weights)
    
    # Further fine-tune on selection images
    print("Fine-tuning on selection images...")
    final_model = fine_tune_model(fine_tuned_model, train_selection_loader, val_loader, class_weights=class_weights)
    
    # Save the fine-tuned model
    save_model_path = 'best_boi_models_finetuned_advanced/ensemble_finetuned_megaload2.pth'
    os.makedirs(os.path.dirname(save_model_path), exist_ok=True)
    torch.save(final_model.state_dict(), save_model_path)
    print(f"Fine-tuned model saved as '{save_model_path}'")

if __name__ == "__main__":
    main()

  model.load_state_dict(torch.load(path))


Fine-tuning on composite images...


Epoch 1/10: 100%|██████████| 28/28 [00:01<00:00, 15.79it/s, Loss=2.59, Accuracy=18.74%] 


Epoch 1 Validation Accuracy: 10.34%


Epoch 2/10: 100%|██████████| 28/28 [00:01<00:00, 22.44it/s, Loss=2.41, Accuracy=19.14%]


Epoch 2 Validation Accuracy: 19.26%


Epoch 3/10: 100%|██████████| 28/28 [00:01<00:00, 21.24it/s, Loss=2.3, Accuracy=19.37%]  


Epoch 3 Validation Accuracy: 20.57%


Epoch 4/10: 100%|██████████| 28/28 [00:01<00:00, 20.53it/s, Loss=2.24, Accuracy=20.86%] 


Epoch 4 Validation Accuracy: 21.71%


Epoch 5/10: 100%|██████████| 28/28 [00:01<00:00, 21.03it/s, Loss=2.19, Accuracy=21.49%] 


Epoch 5 Validation Accuracy: 22.51%


Epoch 6/10: 100%|██████████| 28/28 [00:01<00:00, 20.67it/s, Loss=2.15, Accuracy=22.69%] 


Epoch 6 Validation Accuracy: 24.23%


Epoch 7/10: 100%|██████████| 28/28 [00:01<00:00, 23.15it/s, Loss=2.11, Accuracy=23.89%] 


Epoch 7 Validation Accuracy: 25.31%


Epoch 8/10: 100%|██████████| 28/28 [00:01<00:00, 21.89it/s, Loss=2.09, Accuracy=25.60%] 


Epoch 8 Validation Accuracy: 26.91%


Epoch 9/10: 100%|██████████| 28/28 [00:01<00:00, 22.26it/s, Loss=2.05, Accuracy=26.29%] 


Epoch 9 Validation Accuracy: 28.23%


Epoch 10/10: 100%|██████████| 28/28 [00:01<00:00, 21.28it/s, Loss=2.02, Accuracy=27.89%] 


Epoch 10 Validation Accuracy: 29.31%
Fine-tuning on selection images...


Epoch 1/10: 100%|██████████| 10254/10254 [07:56<00:00, 21.50it/s, Loss=2.3, Accuracy=11.33%] 


Epoch 1 Validation Accuracy: 21.71%


Epoch 2/10: 100%|██████████| 10254/10254 [07:58<00:00, 21.44it/s, Loss=2.29, Accuracy=12.88%]


Epoch 2 Validation Accuracy: 34.86%


Epoch 3/10: 100%|██████████| 10254/10254 [07:55<00:00, 21.56it/s, Loss=2.28, Accuracy=13.50%]


Epoch 3 Validation Accuracy: 45.66%


Epoch 4/10: 100%|██████████| 10254/10254 [07:57<00:00, 21.47it/s, Loss=2.27, Accuracy=13.97%]


Epoch 4 Validation Accuracy: 48.46%


Epoch 5/10: 100%|██████████| 10254/10254 [08:01<00:00, 21.32it/s, Loss=2.27, Accuracy=14.32%]


Epoch 5 Validation Accuracy: 55.83%


Epoch 6/10: 100%|██████████| 10254/10254 [07:59<00:00, 21.40it/s, Loss=2.26, Accuracy=14.63%]


Epoch 6 Validation Accuracy: 54.63%


Epoch 7/10: 100%|██████████| 10254/10254 [08:02<00:00, 21.27it/s, Loss=2.26, Accuracy=14.98%]


Epoch 7 Validation Accuracy: 54.91%


Epoch 8/10: 100%|██████████| 10254/10254 [07:56<00:00, 21.51it/s, Loss=2.25, Accuracy=15.32%]


Epoch 8 Validation Accuracy: 52.57%


Epoch 9/10: 100%|██████████| 10254/10254 [07:56<00:00, 21.52it/s, Loss=2.25, Accuracy=15.84%]


Epoch 9 Validation Accuracy: 50.11%


Epoch 10/10: 100%|██████████| 10254/10254 [07:55<00:00, 21.57it/s, Loss=2.24, Accuracy=16.10%]


Epoch 10 Validation Accuracy: 49.89%
Fine-tuned model saved as 'best_boi_models_finetuned_advanced/ensemble_finetuned_megaload2.pth'


Oh baby this is going to be interesting

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
import numpy as np
from PIL import Image
import io
import base64
import json
import os
from tqdm import tqdm
from mnist_skeptic_v9 import skeptic_v9

In [12]:
class WeightedCrossEntropyLoss(nn.Module):
    def __init__(self, class_weights):
        super().__init__()
        self.class_weights = class_weights

    def forward(self, outputs, targets):
        return F.cross_entropy(outputs, targets, weight=self.class_weights)

class CompositeDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def base64_to_image(self, base64_string):
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data)).convert('L')
        image_array = np.array(image, dtype=np.float32) / 255.0
        return torch.from_numpy(image_array).unsqueeze(0)  # Add channel dimension

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image = self.base64_to_image(item['composite'])
        label = torch.tensor(int(item['true_digit']), dtype=torch.long)
        return image, label

class SelectionDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = [item for sublist in data for item in sublist]

    def base64_to_image(self, base64_string):
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data)).convert('L')
        image_array = np.array(image, dtype=np.float32) / 255.0
        return torch.from_numpy(image_array).unsqueeze(0)  # Add channel dimension

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image = self.base64_to_image(item['selected_image'])
        label = torch.tensor(int(item['true_digit']), dtype=torch.long)
        return image, label

def load_data(train_folder, test_folder, num_participants=10):
    train_data = []
    for file in os.listdir(train_folder)[:num_participants]:
        with open(os.path.join(train_folder, file), 'r') as f:
            train_data.extend(json.load(f))
    
    test_data = []
    for file in os.listdir(test_folder)[:num_participants]:
        with open(os.path.join(test_folder, file), 'r') as f:
            test_data.extend(json.load(f))
    
    return train_data, test_data

def fine_tune_model(model, train_loader, val_loader, num_epochs=12, learning_rate=0.0001, class_weights=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = WeightedCrossEntropyLoss(class_weights.to(device) if class_weights is not None else None)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    best_val_accuracy = 0
    best_model_state = None
    patience = 5
    no_improve = 0
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            accuracy = 100 * correct / total
            progress_bar.set_postfix({'Loss': running_loss/len(progress_bar), 'Accuracy': f'{accuracy:.2f}%'})
        
        # Validation
        val_accuracy = evaluate_model(model, val_loader, device)
        print(f'Epoch {epoch+1} Validation Accuracy: {val_accuracy:.2f}%')
        
        scheduler.step()
        
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_model_state = model.state_dict()
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"Early stopping triggered. No improvement for {patience} epochs.")
                break
    
    # Load the best model state
    model.load_state_dict(best_model_state)
    return model

def main():
    # Load the pre-trained ensemble model
    ensemble_model = create_ensemble()
    
    # Load data
    train_data, test_data = load_data('training_data/training_set', 'training_data/test_set', num_participants=10)
    
    # Create datasets
    train_composite_dataset = CompositeDataset(test_data)
    train_selection_dataset = SelectionDataset(train_data)
    val_composite_dataset = CompositeDataset(test_data)
    
    # Calculate class weights
    class_weights = calculate_class_weights(ConcatDataset([train_composite_dataset, train_selection_dataset]))
    
    # Create data loaders
    train_composite_loader = DataLoader(train_composite_dataset, batch_size=64, shuffle=True)
    train_selection_loader = DataLoader(train_selection_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_composite_dataset, batch_size=64, shuffle=False)
    
    # Fine-tune the model on composite images
    print("Fine-tuning on composite images...")
    fine_tuned_model = fine_tune_model(ensemble_model, train_composite_loader, val_loader, class_weights=class_weights)
    
    # Further fine-tune on selection images
    print("Fine-tuning on selection images...")
    final_model = fine_tune_model(fine_tuned_model, train_selection_loader, val_loader, class_weights=class_weights)
    
    # Save the fine-tuned model
    save_model_path = 'best_boi_models_finetuned_advanced/ensemble_finetuned_megaload3.pth'
    os.makedirs(os.path.dirname(save_model_path), exist_ok=True)
    torch.save(final_model.state_dict(), save_model_path)
    print(f"Fine-tuned model saved as '{save_model_path}'")

if __name__ == "__main__":
    main()

  model.load_state_dict(torch.load(path))


Fine-tuning on composite images...


Epoch 1/12: 100%|██████████| 55/55 [00:02<00:00, 20.95it/s, Loss=2.52, Accuracy=18.71%] 


Epoch 1 Validation Accuracy: 19.40%


Epoch 2/12: 100%|██████████| 55/55 [00:02<00:00, 20.09it/s, Loss=2.3, Accuracy=19.94%]  


Epoch 2 Validation Accuracy: 21.46%


Epoch 3/12: 100%|██████████| 55/55 [00:02<00:00, 20.18it/s, Loss=2.19, Accuracy=21.71%] 


Epoch 3 Validation Accuracy: 23.20%


Epoch 4/12: 100%|██████████| 55/55 [00:02<00:00, 20.93it/s, Loss=2.13, Accuracy=23.03%] 


Epoch 4 Validation Accuracy: 24.83%


Epoch 5/12: 100%|██████████| 55/55 [00:02<00:00, 21.13it/s, Loss=2.08, Accuracy=24.86%] 


Epoch 5 Validation Accuracy: 26.49%


Epoch 6/12: 100%|██████████| 55/55 [00:02<00:00, 20.19it/s, Loss=2.05, Accuracy=26.26%] 


Epoch 6 Validation Accuracy: 27.91%


Epoch 7/12: 100%|██████████| 55/55 [00:02<00:00, 21.44it/s, Loss=2.01, Accuracy=27.66%] 


Epoch 7 Validation Accuracy: 29.09%


Epoch 8/12: 100%|██████████| 55/55 [00:02<00:00, 20.78it/s, Loss=1.99, Accuracy=29.17%] 


Epoch 8 Validation Accuracy: 29.74%


Epoch 9/12: 100%|██████████| 55/55 [00:02<00:00, 21.29it/s, Loss=1.97, Accuracy=29.89%] 


Epoch 9 Validation Accuracy: 30.37%


Epoch 10/12: 100%|██████████| 55/55 [00:02<00:00, 19.68it/s, Loss=1.96, Accuracy=31.11%] 


Epoch 10 Validation Accuracy: 31.00%


Epoch 11/12: 100%|██████████| 55/55 [00:02<00:00, 21.09it/s, Loss=1.95, Accuracy=30.51%] 


Epoch 11 Validation Accuracy: 30.86%


Epoch 12/12: 100%|██████████| 55/55 [00:02<00:00, 21.68it/s, Loss=1.95, Accuracy=30.63%] 


Epoch 12 Validation Accuracy: 30.97%
Fine-tuning on selection images...


Epoch 1/12: 100%|██████████| 20508/20508 [16:37<00:00, 20.56it/s, Loss=2.29, Accuracy=12.18%]


Epoch 1 Validation Accuracy: 40.71%


Epoch 2/12: 100%|██████████| 20508/20508 [16:37<00:00, 20.56it/s, Loss=2.28, Accuracy=13.54%]


Epoch 2 Validation Accuracy: 42.77%


Epoch 3/12: 100%|██████████| 20508/20508 [16:33<00:00, 20.64it/s, Loss=2.27, Accuracy=13.90%]


Epoch 3 Validation Accuracy: 50.43%


Epoch 4/12: 100%|██████████| 20508/20508 [16:34<00:00, 20.61it/s, Loss=2.27, Accuracy=14.04%]


Epoch 4 Validation Accuracy: 54.06%


Epoch 5/12: 100%|██████████| 20508/20508 [16:37<00:00, 20.57it/s, Loss=2.27, Accuracy=14.21%]


Epoch 5 Validation Accuracy: 54.46%


Epoch 6/12: 100%|██████████| 20508/20508 [16:36<00:00, 20.58it/s, Loss=2.27, Accuracy=14.41%]


Epoch 6 Validation Accuracy: 55.46%


Epoch 7/12: 100%|██████████| 20508/20508 [16:38<00:00, 20.54it/s, Loss=2.26, Accuracy=14.53%]


Epoch 7 Validation Accuracy: 55.37%


Epoch 8/12: 100%|██████████| 20508/20508 [16:37<00:00, 20.57it/s, Loss=2.26, Accuracy=14.71%]


Epoch 8 Validation Accuracy: 57.63%


Epoch 9/12: 100%|██████████| 20508/20508 [16:33<00:00, 20.65it/s, Loss=2.26, Accuracy=14.79%]


Epoch 9 Validation Accuracy: 58.14%


Epoch 10/12: 100%|██████████| 20508/20508 [16:36<00:00, 20.59it/s, Loss=2.26, Accuracy=14.96%]


Epoch 10 Validation Accuracy: 58.23%


Epoch 11/12: 100%|██████████| 20508/20508 [16:35<00:00, 20.60it/s, Loss=2.26, Accuracy=15.01%]


Epoch 11 Validation Accuracy: 57.54%


Epoch 12/12: 100%|██████████| 20508/20508 [16:35<00:00, 20.59it/s, Loss=2.26, Accuracy=15.08%]


Epoch 12 Validation Accuracy: 57.80%
Fine-tuned model saved as 'best_boi_models_finetuned_advanced/ensemble_finetuned_megaload3.pth'
