In [2]:
import copy
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from tqdm import tqdm
from typing import List, Tuple
from dataclasses import dataclass
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

In [3]:
@dataclass
class Hyperparameters:
    width_mult: float
    learning_rate: float
    batch_size: int
    dropout_rate: float
    weight_decay: float
    momentum: float
    conv_channels: List[int]
    
    def __post_init__(self):
        self.width_mult = max(0.25, min(2.0, self.width_mult))
        self.learning_rate = max(0.0001, min(0.1, self.learning_rate))
        self.batch_size = max(32, min(256, int(self.batch_size)))
        self.dropout_rate = max(0.0, min(0.5, self.dropout_rate))
        self.weight_decay = max(1e-6, min(1e-2, self.weight_decay))
        self.momentum = max(0.1, min(0.99, self.momentum))

In [4]:
class FEMNISTDataset(Dataset):
    def __init__(self, hf_split, transform):
        self.data = hf_split
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        img = item['image']
        label = int(item['character'])
        if self.transform:
            img = self.transform(img)
        return img, label

In [5]:
class SeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, kernel=3, stride=1, padding=1, dropout_rate=0.0):
        super().__init__()
        self.dw = nn.Conv2d(in_ch, in_ch, kernel, stride, padding, groups=in_ch, bias=False)
        self.pw = nn.Conv2d(in_ch, out_ch, 1, 1, 0, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
        
    def forward(self, x):
        x = self.dw(x)
        x = self.pw(x)
        x = self.bn(x)
        x = F.relu(x, inplace=True)
        return self.dropout(x)

In [6]:
class CNN(nn.Module):
    def __init__(self, num_classes=62, hyperparams=None):
        super().__init__()
        if hyperparams is None:
            hyperparams = Hyperparameters(
                width_mult=1.0, learning_rate=0.01, batch_size=128,
                dropout_rate=0.2, weight_decay=2e-4, momentum=0.9,
                conv_channels=[16, 32, 64, 128]
            )
        
        self.hyperparams = hyperparams
        def c(ch): return max(8, int(ch * hyperparams.width_mult))
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, c(hyperparams.conv_channels[0]), 3, 1, 1, bias=False),
            nn.BatchNorm2d(c(hyperparams.conv_channels[0])),
            nn.ReLU(inplace=True),
            nn.Dropout2d(hyperparams.dropout_rate * 0.5)
        )
        
        self.sep1 = SeparableConv(
            c(hyperparams.conv_channels[0]), c(hyperparams.conv_channels[1]), 
            stride=2, dropout_rate=hyperparams.dropout_rate * 0.7
        )
        self.sep2 = SeparableConv(
            c(hyperparams.conv_channels[1]), c(hyperparams.conv_channels[2]), 
            stride=2, dropout_rate=hyperparams.dropout_rate * 0.8
        )
        self.sep3 = SeparableConv(
            c(hyperparams.conv_channels[2]), c(hyperparams.conv_channels[3]), 
            stride=2, dropout_rate=hyperparams.dropout_rate
        )
        
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout_fc = nn.Dropout(hyperparams.dropout_rate)
        self.fc = nn.Linear(c(hyperparams.conv_channels[3]), num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.sep1(x)
        x = self.sep2(x)
        x = self.sep3(x)
        x = self.pool(x)
        x = self.dropout_fc(x.view(x.size(0), -1))
        return self.fc(x)

In [7]:
class GeneticAlgorithm:
    def __init__(self, population_size=8, mutation_rate=0.3, crossover_rate=0.7):
        self.population_size = population_size
        self.mutation_rate = mutation_rate
        self.crossover_rate = crossover_rate
        self.generation = 0
        
    def create_individual(self):
        return Hyperparameters(
            width_mult=random.uniform(0.5, 1.5),
            learning_rate=random.uniform(0.005, 0.05),
            batch_size=random.choice([64, 96, 128, 160, 192]),
            dropout_rate=random.uniform(0.1, 0.4),
            weight_decay=random.uniform(1e-5, 5e-3),
            momentum=random.uniform(0.8, 0.95),
            conv_channels=[16, 32, 64, 128]
        )
    
    def create_population(self):
        population = []
        baseline = Hyperparameters(
            width_mult=1.0, learning_rate=0.01, batch_size=128,
            dropout_rate=0.2, weight_decay=2e-4, momentum=0.9,
            conv_channels=[16, 32, 64, 128]
        )
        population.append(baseline)
        
        for _ in range(self.population_size - 1):
            population.append(self.create_individual())
        return population
    
    def crossover(self, parent1, parent2) :
        child1_dict = {}
        child2_dict = {}
        
        for key in parent1.__dict__.keys():
            if random.random() < 0.5:
                child1_dict[key] = getattr(parent1, key)
                child2_dict[key] = getattr(parent2, key)
            else:
                child1_dict[key] = getattr(parent2, key)
                child2_dict[key] = getattr(parent1, key)
        
        child1 = Hyperparameters(**child1_dict)
        child2 = Hyperparameters(**child2_dict)
        return child1, child2
    
    def mutate(self, individual):
        mutated = copy.deepcopy(individual)
        
        if random.random() < self.mutation_rate:
            mutated.width_mult += random.gauss(0, 0.2)
        if random.random() < self.mutation_rate:
            mutated.learning_rate *= random.uniform(0.5, 2.0)
        if random.random() < self.mutation_rate:
            mutated.batch_size = random.choice([64, 96, 128, 160, 192])
        if random.random() < self.mutation_rate:
            mutated.dropout_rate += random.gauss(0, 0.1)
        if random.random() < self.mutation_rate:
            mutated.weight_decay *= random.uniform(0.1, 10.0)
        if random.random() < self.mutation_rate:
            mutated.momentum += random.gauss(0, 0.05)
            
        mutated.__post_init__()
        return mutated
    
    def select_parents(self, population, fitness_scores):
        selected = []
        tournament_size = 3
        
        for _ in range(len(population)):
            tournament_indices = random.sample(range(len(population)), min(tournament_size, len(population)))
            tournament_fitness = [fitness_scores[i] for i in tournament_indices]
            winner_idx = tournament_indices[np.argmax(tournament_fitness)]
            selected.append(copy.deepcopy(population[winner_idx]))
        
        return selected

In [8]:
def train_and_evaluate_model(hyperparams, train_loader, val_loader, device, max_epochs=3):
    try:
        model = CNN(num_classes=62, hyperparams=hyperparams).to(device)
        
        optimizer = torch.optim.SGD(
            model.parameters(), 
            lr=hyperparams.learning_rate,
            momentum=hyperparams.momentum,
            weight_decay=hyperparams.weight_decay
        )
        criterion = nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
        
        best_val_acc = 0.0
        
        for epoch in range(max_epochs):
            model.train()
            for batch_idx, (data, target) in enumerate(train_loader):
                if batch_idx > 200:  # Limit training for GA speed
                    break
                    
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
            
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for batch_idx, (data, target) in enumerate(val_loader):
                    if batch_idx > 50:
                        break
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    pred = output.argmax(dim=1)
                    correct += pred.eq(target).sum().item()
                    total += target.size(0)
            
            val_acc = correct / total
            best_val_acc = max(best_val_acc, val_acc)
            scheduler.step()
        
        return best_val_acc
    
    except Exception as e:
        print(f"Error training model: {e}")
        return 0.0

In [9]:
print("Loading FEMNIST dataset...")
ds = load_dataset("flwrlabs/femnist")

transform = T.Compose([
    T.Resize((28,28)),
    T.Grayscale(num_output_channels=1),
    T.ToTensor(),
    T.Normalize((0.5,), (0.5,))
])

temp_split = ds['train'].train_test_split(test_size=0.1, seed=42)
test_split = temp_split['test']

train_valid = temp_split['train'].train_test_split(test_size=0.1, seed=42)

ga_train_size = 20000
ga_val_size = 5000

ga_train_data = train_valid['train'].select(range(min(ga_train_size, len(train_valid['train']))))
ga_val_data = train_valid['test'].select(range(min(ga_val_size, len(train_valid['test']))))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Loading FEMNIST dataset...


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/201M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/814277 [00:00<?, ? examples/s]

Using device: cuda


In [10]:
ga = GeneticAlgorithm(population_size=6, mutation_rate=0.3, crossover_rate=0.7)
population = ga.create_population()

generations = 5
best_fitness_history = []
best_individual = None
best_fitness = 0.0

In [11]:
print(f"Starting GA optimization with {len(population)} individuals for {generations} generations...")

for generation in range(generations):
    print(f"\n---Generation {generation + 1}/{generations}---")
    fitness_scores = []
    
    for i, individual in enumerate(population):
        print(f"Evaluating individual {i+1}/{len(population)}...")
        
        train_ds = FEMNISTDataset(ga_train_data, transform)
        val_ds = FEMNISTDataset(ga_val_data, transform)
        
        train_loader = DataLoader(train_ds, batch_size=individual.batch_size, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_ds, batch_size=individual.batch_size, shuffle=False, num_workers=2)
        
        fitness = train_and_evaluate_model(individual, train_loader, val_loader, device)
        fitness_scores.append(fitness)
        
        print(f"->Fitness: {fitness:.4f}")
        print(f"->Params: learning_rate={individual.learning_rate:.4f}, weight_decay={individual.weight_decay:.6f}, "
              f"dropout_rate={individual.dropout_rate:.3f}, width_mult={individual.width_mult:.3f}")
        
        if fitness > best_fitness:
            best_fitness = fitness
            best_individual = copy.deepcopy(individual)
    
    avg_fitness = np.mean(fitness_scores)
    max_fitness = np.max(fitness_scores)
    best_fitness_history.append(max_fitness)
    
    print(f"Generation {generation + 1} - Avg: {avg_fitness:.4f}, Best: {max_fitness:.4f}")
    
    if generation < generations - 1:
        selected_parents = ga.select_parents(population, fitness_scores)
        new_population = []
        
        best_idx = np.argmax(fitness_scores)
        new_population.append(copy.deepcopy(population[best_idx]))
        
        while len(new_population) < ga.population_size:
            parent1, parent2 = random.sample(selected_parents, 2)
            if random.random() < ga.crossover_rate:
                child1, child2 = ga.crossover(parent1, parent2)
            else:
                child1, child2 = copy.deepcopy(parent1), copy.deepcopy(parent2)
            
            child1 = ga.mutate(child1)
            child2 = ga.mutate(child2)
            
            new_population.extend([child1, child2])
        
        population = new_population[:ga.population_size]

print(f"\n---GA Optimization Done---")
print(f"Best fitness: {best_fitness:.4f}")
print(f"Best hyperparams:")
print(f"->Learning rate: {best_individual.learning_rate:.6f}")
print(f"->Weight decay: {best_individual.weight_decay:.6f}")
print(f"->Dropout rate: {best_individual.dropout_rate:.4f}")
print(f"->Width multiplier: {best_individual.width_mult:.4f}")
print(f"->Batch size: {best_individual.batch_size}")
print(f"->Momentum: {best_individual.momentum:.4f}")

Starting GA optimization with 6 individuals for 5 generations...

---Generation 1/5---
Evaluating individual 1/6...
->Fitness: 0.3042
->Params: learning_rate=0.0100, weight_decay=0.000200, dropout_rate=0.200, width_mult=1.000
Evaluating individual 2/6...
->Fitness: 0.1936
->Params: learning_rate=0.0061, weight_decay=0.000706, dropout_rate=0.173, width_mult=1.139
Evaluating individual 3/6...
->Fitness: 0.3690
->Params: learning_rate=0.0295, weight_decay=0.000159, dropout_rate=0.227, width_mult=1.241
Evaluating individual 4/6...
->Fitness: 0.2355
->Params: learning_rate=0.0062, weight_decay=0.003510, dropout_rate=0.315, width_mult=1.005
Evaluating individual 5/6...
->Fitness: 0.3355
->Params: learning_rate=0.0175, weight_decay=0.000807, dropout_rate=0.328, width_mult=0.949
Evaluating individual 6/6...
->Fitness: 0.3522
->Params: learning_rate=0.0147, weight_decay=0.001906, dropout_rate=0.131, width_mult=0.778
Generation 1 - Avg: 0.2983, Best: 0.3690

---Generation 2/5---
Evaluating indiv

In [14]:
full_train_ds = FEMNISTDataset(train_valid['train'], transform)
full_val_ds = FEMNISTDataset(train_valid['test'], transform)
full_test_ds = FEMNISTDataset(test_split, transform)

train_loader = DataLoader(full_train_ds, batch_size=best_individual.batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(full_val_ds, batch_size=best_individual.batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(full_test_ds, batch_size=best_individual.batch_size, shuffle=False, num_workers=2)

final_model = CNN(num_classes=62, hyperparams=best_individual).to(device)
optimizer = torch.optim.SGD(
    final_model.parameters(),
    lr=best_individual.learning_rate,
    momentum=best_individual.momentum,
    weight_decay=best_individual.weight_decay
)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

print(f"Final model parameters: {sum(p.numel() for p in final_model.parameters())}")

Final model parameters: 54396


In [15]:
print(f"\nTraining Model with optimised hyperparams...")

epochs = 10
best_val_acc = 0.0

for epoch in range(epochs):
    final_model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = final_model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        pred = output.argmax(dim=1)
        train_correct += pred.eq(target).sum().item()
        train_total += target.size(0)
        
        if batch_idx % 500 == 0:
            pbar.set_postfix({
                'loss': f'{train_loss/(batch_idx+1):.4f}',
                'acc': f'{100.*train_correct/train_total:.2f}%'
            })
    
    final_model.eval()
    val_correct = 0
    val_total = 0
    val_loss = 0
    
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = final_model(data)
            loss = criterion(output, target)
            val_loss += loss.item()
            pred = output.argmax(dim=1)
            val_correct += pred.eq(target).sum().item()
            val_total += target.size(0)
    
    val_acc = val_correct / val_total
    train_acc = train_correct / train_total
    
    print(f"Epoch {epoch+1}: Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(final_model.state_dict(), "ga_optimized_model.pth")
    
    scheduler.step()


Training Model with optimised hyperparams...


Epoch 1/10: 100%|██████████| 10306/10306 [03:15<00:00, 52.81it/s, loss=0.8422, acc=74.83%]


Epoch 1: Train Acc: 0.7497, Val Acc: 0.8338


Epoch 2/10: 100%|██████████| 10306/10306 [03:13<00:00, 53.23it/s, loss=0.6214, acc=80.20%]


Epoch 2: Train Acc: 0.8021, Val Acc: 0.8391


Epoch 3/10: 100%|██████████| 10306/10306 [03:13<00:00, 53.16it/s, loss=0.5924, acc=81.00%]


Epoch 3: Train Acc: 0.8101, Val Acc: 0.8461


Epoch 4/10: 100%|██████████| 10306/10306 [03:14<00:00, 52.89it/s, loss=0.5746, acc=81.54%]


Epoch 4: Train Acc: 0.8153, Val Acc: 0.8424


Epoch 5/10: 100%|██████████| 10306/10306 [03:15<00:00, 52.67it/s, loss=0.5574, acc=82.05%]


Epoch 5: Train Acc: 0.8204, Val Acc: 0.8483


Epoch 6/10: 100%|██████████| 10306/10306 [03:15<00:00, 52.60it/s, loss=0.5431, acc=82.51%]


Epoch 6: Train Acc: 0.8250, Val Acc: 0.8516


Epoch 7/10: 100%|██████████| 10306/10306 [03:15<00:00, 52.73it/s, loss=0.5277, acc=82.92%]


Epoch 7: Train Acc: 0.8290, Val Acc: 0.8559


Epoch 8/10: 100%|██████████| 10306/10306 [03:13<00:00, 53.30it/s, loss=0.5126, acc=83.38%]


Epoch 8: Train Acc: 0.8338, Val Acc: 0.8599


Epoch 9/10: 100%|██████████| 10306/10306 [03:13<00:00, 53.35it/s, loss=0.4955, acc=83.87%]


Epoch 9: Train Acc: 0.8387, Val Acc: 0.8612


Epoch 10/10: 100%|██████████| 10306/10306 [03:11<00:00, 53.77it/s, loss=0.4850, acc=84.13%]


Epoch 10: Train Acc: 0.8414, Val Acc: 0.8637


In [16]:
final_model.load_state_dict(torch.load("ga_optimized_model.pth", map_location=device))
final_model.eval()

test_correct = 0
test_total = 0

with torch.no_grad():
    for data, target in tqdm(test_loader, desc="Final Test"):
        data, target = data.to(device), target.to(device)
        output = final_model(data)
        pred = output.argmax(dim=1)
        test_correct += pred.eq(target).sum().item()
        test_total += target.size(0)

final_test_acc = test_correct / test_total

print(f"\nResults:")
print(f"Final Test Accuracy: {final_test_acc:.4f}")
print(f"Model saved as 'ga_optimized_model.pth'")

Final Test: 100%|██████████| 1273/1273 [00:23<00:00, 54.97it/s]


Results:
Final Test Accuracy: 0.8638
Model saved as 'ga_optimized_model.pth'



