# ETL Galaxy Zoo 2 - Model training

by BRAUX Owen and CAMBIER Elliot in 2026

## Imports 

In [None]:
import os
import random
import time
import copy
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, random_split

## Configuration

In [None]:
# Experiment label; not referenced by the rest of the visible notebook.
PROJECT_NAME = "Galaxy_ResNet18_M2"
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64  # samples per batch, passed to get_data_loaders
IMG_SIZE = 128  # images are resized to IMG_SIZE x IMG_SIZE pixels
EPOCHS = 15  # number of epochs passed to train_model
LEARNING_RATE = 1e-3  # learning rate given to the Adam optimizer
SEED = 42  # seed for the RNGs and for the train/validation split

# Root folder of the images, read with ImageFolder (one sub-folder per class):
# - Kaggle: "/kaggle/input/galaxy-dataset-brauxo/" (path to complete; the
#   dataset is private, so you need to create your own)
# - Local: "dataset_images"
DATA_DIR = "dataset_images"

def set_seed(seed=42):
    """Seed every random number generator used by the notebook.

    Python's ``random`` module, NumPy and PyTorch are all seeded with the
    same value. When CUDA is available, every GPU generator is seeded too
    and cuDNN is switched to deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Nothing GPU-specific to configure on a CPU-only machine.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before any dataset split or model initialisation.
set_seed(SEED)

## Data loader

In [None]:
def get_data_loaders(data_dir, batch_size, img_size, seed=None, num_workers=2):
    """Build the training and validation loaders from an image folder.

    The folder is read twice with ``ImageFolder``: once with the augmenting
    training transforms and once with the plain validation transforms. A
    seeded 80% / 20% random split is drawn once and its indices are applied
    to both views, so the two subsets never overlap and each one gets its
    own transform pipeline.

    Args:
        data_dir: Root directory in ``ImageFolder`` layout (one sub-folder
            per class).
        batch_size: Number of samples per batch for both loaders.
        img_size: Images are resized to ``img_size`` x ``img_size``.
        seed: Seed of the generator used for the split. Defaults to the
            module-level ``SEED`` so the split is always the same.
        num_workers: Number of worker processes used by each loader.

    Returns:
        ``(train_loader, val_loader, class_names)``, or
        ``(None, None, None)`` when ``ImageFolder`` raises
        ``FileNotFoundError`` for ``data_dir``.
    """
    if seed is None:
        seed = SEED

    # ImageNet statistics, required by the pretrained backbone (transfer learning).
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    train_transforms = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomRotation(180),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])

    # Validation: no augmentation, only resize + normalisation.
    val_transforms = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])

    # Two views of the same files, one per transform pipeline. Both
    # constructions are guarded so a missing folder is reported once.
    try:
        train_ds = datasets.ImageFolder(root=data_dir, transform=train_transforms)
        val_ds = datasets.ImageFolder(root=data_dir, transform=val_transforms)
    except FileNotFoundError:
        print(f"ERREUR : Le dossier '{data_dir}' est introuvable.")
        return None, None, None

    # Split Train/Val (80% / 20%)
    train_size = int(0.8 * len(train_ds))
    val_size = len(train_ds) - train_size

    # The seeded generator makes the split identical on every run.
    train_data, val_split = random_split(
        train_ds, [train_size, val_size],
        generator=torch.Generator().manual_seed(seed),
    )
    # Re-use the validation indices on the non-augmented view of the data.
    val_data = torch.utils.data.Subset(val_ds, val_split.indices)

    # Pinned memory only helps host-to-GPU copies; requesting it without
    # CUDA just triggers a warning.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)

    return train_loader, val_loader, train_ds.classes

train_loader, val_loader, class_names = get_data_loaders(DATA_DIR, BATCH_SIZE, IMG_SIZE)

# get_data_loaders returns (None, None, None) when the folder cannot be read;
# stop here with a clear error instead of failing later on `None.dataset`.
if train_loader is None or val_loader is None:
    raise FileNotFoundError(
        f"Dataset folder '{DATA_DIR}' could not be loaded; "
        "set DATA_DIR to a valid ImageFolder directory."
    )

print(f"Classes détectées : {class_names}")
print(f"Train size : {len(train_loader.dataset)} images")
print(f"Val size   : {len(val_loader.dataset)} images")

## Model

In [None]:
class GalaxyResNet(nn.Module):
    """ResNet-18 with its final layer replaced by a new classification head.

    The network is created with torchvision's default ResNet-18 weights and
    its ``fc`` layer is swapped for a fresh ``nn.Linear`` producing
    ``num_classes`` outputs.
    """

    def __init__(self, num_classes=10):
        """Create the network and attach the new head.

        Args:
            num_classes: Number of outputs of the replacement ``fc`` layer.
        """
        super().__init__()
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Keep the input width of the original head, change only its output size.
        resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)
        self.backbone = resnet

    def forward(self, x):
        """Return the backbone's output for the batch ``x``."""
        logits = self.backbone(x)
        return logits
    
# Size the classification head from the number of detected classes.
model = GalaxyResNet(num_classes=len(class_names)).to(DEVICE)
# Report the real device: DEVICE falls back to the CPU when CUDA is unavailable.
print(f"Modèle ResNet18 initialisé et envoyé sur {DEVICE}.")

## Training

In [None]:
def _run_phase(model, dataloader, criterion, optimizer, device):
    """Run one full pass over ``dataloader`` and return ``(mean_loss, accuracy)``.

    Gradients are computed and the weights updated only when ``optimizer``
    is not ``None``; otherwise the pass is a pure evaluation.
    """
    training = optimizer is not None
    total_loss = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        if training:
            optimizer.zero_grad()

        # Forward (autograd is recorded only while training)
        with torch.set_grad_enabled(training):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            # Backward + optimize
            if training:
                loss.backward()
                optimizer.step()

        # Accumulate plain Python numbers so the statistics stay off the
        # autograd graph and off the GPU.
        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        n_correct += (preds == labels).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("The dataloader produced no samples.")

    # Average over the samples actually processed, which stays correct even
    # if the loader drops or re-samples items.
    return total_loss / n_seen, n_correct / n_seen


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                num_epochs=10, checkpoint_path='best_galaxy_resnet.pth'):
    """Train ``model`` and keep the weights with the best validation accuracy.

    Every epoch runs a training pass followed by a validation pass. The
    scheduler is stepped once per epoch, right after the training pass.
    Whenever the validation accuracy improves, the weights are copied in
    memory and written to ``checkpoint_path``.

    Args:
        model: Network to train. Batches are moved to the device of its
            first parameter (CPU when the model has no parameters).
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader yielding ``(inputs, labels)`` validation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating the model parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of epochs to run.
        checkpoint_path: File receiving the best ``state_dict``; pass
            ``None`` to skip writing a checkpoint.

    Returns:
        ``(model, history)`` where ``model`` has the best validation weights
        loaded and ``history`` maps ``'train_loss'``, ``'train_acc'``,
        ``'val_loss'`` and ``'val_acc'`` to per-epoch lists of floats.

    Raises:
        ValueError: If one of the loaders yields no samples.
    """
    since = time.time()

    # Follow the model's own device instead of relying on a global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    loaders = {'train': train_loader, 'val': val_loader}

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ('train', 'val'):
            is_train = phase == 'train'
            # train(True) enables training mode, train(False) is eval().
            model.train(is_train)

            epoch_loss, epoch_acc = _run_phase(
                model, loaders[phase], criterion,
                optimizer if is_train else None, device,
            )

            if is_train:
                scheduler.step()

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            # Deep copy of the best model so later epochs cannot overwrite it.
            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                if checkpoint_path is not None:
                    torch.save(model.state_dict(), checkpoint_path)

        print()

    time_elapsed = time.time() - since
    print(f'Entraînement terminé en {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Meilleure Val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model, history

# Loss function handed to train_model, which calls it as criterion(outputs, labels).
criterion = nn.CrossEntropyLoss()
# Adam over every parameter of the model (no layer is excluded here).
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Train for EPOCHS epochs; the returned model carries its best validation
# weights and `history` holds the per-epoch loss/accuracy curves.
model, history = train_model(model, train_loader, val_loader, criterion, optimizer, exp_lr_scheduler, num_epochs=EPOCHS)

## Analysis (plotting)

In [None]:
def plot_history(history):
    """Plot the accuracy and loss curves side by side.

    Args:
        history: Mapping with the keys ``'train_acc'``, ``'val_acc'``,
            ``'train_loss'`` and ``'val_loss'``, each holding one value per
            epoch.
    """
    # Fetch every series up front, before any figure is created.
    train_acc = history['train_acc']
    valid_acc = history['val_acc']
    train_loss = history['train_loss']
    valid_loss = history['val_loss']
    x_axis = range(1, len(train_acc) + 1)

    # One entry per panel: (training series, label), (validation series, label), title.
    panels = [
        ((train_acc, 'Training Acc'),
         (valid_acc, 'Validation Acc'),
         'Training and Validation Accuracy'),
        ((train_loss, 'Training Loss'),
         (valid_loss, 'Validation Loss'),
         'Training and Validation Loss'),
    ]

    plt.figure(figsize=(14, 5))

    for slot, (train_curve, valid_curve, heading) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.plot(x_axis, train_curve[0], 'b-', label=train_curve[1])
        plt.plot(x_axis, valid_curve[0], 'r-', label=valid_curve[1])
        plt.title(heading)
        plt.legend()
        plt.grid(True)

    plt.show()

def plot_confusion_matrix(model, loader, class_names):
    """Plot the confusion matrix of ``model`` on ``loader`` and print a report.

    The model is put in evaluation mode and run without gradients on every
    batch. The full list of class indices is passed to scikit-learn so the
    matrix and the report always have one row per entry of ``class_names``,
    even when a class never appears in the evaluated data.

    Args:
        model: Trained network; batches are moved to ``DEVICE`` before the
            forward pass.
        loader: Loader yielding ``(inputs, labels)`` batches.
        class_names: Class names indexed by label id. The part before the
            first underscore is stripped for display when there is one.
    """
    y_pred = []
    y_true = []
    model.eval()

    print("Calcul de la matrice de confusion...")
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(DEVICE))
            _, preds = torch.max(outputs, 1)
            y_pred.extend(preds.cpu().tolist())
            y_true.extend(labels.cpu().tolist())

    # Fix the label set explicitly: without it, a class missing from both
    # y_true and y_pred shrinks the matrix and misaligns the tick labels.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    clean_names = [c.split('_', 1)[1] if '_' in c else c for c in class_names]

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=clean_names, yticklabels=clean_names)
    plt.ylabel('Vraie Classe')
    plt.xlabel('Classe Prédite')
    plt.title('Matrice de Confusion')
    plt.xticks(rotation=45, ha='right')
    plt.show()

    print(classification_report(y_true, y_pred, labels=label_ids,
                                target_names=clean_names))

# Plot the training curves, then evaluate the trained model on the validation loader.
plot_history(history)
plot_confusion_matrix(model, val_loader, class_names)

## Visualisation of predictions

In [None]:
def visualize_predictions(model, loader, num_images=10):
    """Show the first ``num_images`` samples of ``loader`` with their predictions.

    Each image is de-normalised with the ImageNet statistics, drawn in a
    grid of five columns and titled with its true and predicted class
    (green when they match, red otherwise). Class names are read from the
    module-level ``class_names`` list.

    Args:
        model: Trained network; batches are moved to ``DEVICE`` before the
            forward pass.
        loader: Loader yielding ``(inputs, labels)`` batches.
        num_images: Maximum number of images to display. Nothing is drawn
            when it is zero or negative.
    """
    if num_images <= 0:
        return

    # Size the grid from num_images (10 images -> 2 x 5, as before) so that
    # larger requests do not overflow a fixed 2 x 5 layout.
    n_cols = 5
    n_rows = (num_images + n_cols - 1) // n_cols

    # ImageNet statistics used to undo the input normalisation.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    def short_name(name):
        # Same display rule as the confusion matrix: drop the prefix before
        # the first underscore, but tolerate names without one.
        return name.split('_', 1)[1] if '_' in name else name

    model.eval()
    images_so_far = 0
    done = False
    plt.figure(figsize=(15, 5 * n_rows))

    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(DEVICE))
            _, preds = torch.max(outputs, 1)

            images = inputs.cpu().numpy()
            true_ids = labels.cpu().tolist()
            pred_ids = preds.cpu().tolist()

            for image, true_id, pred_id in zip(images, true_ids, pred_ids):
                images_so_far += 1
                ax = plt.subplot(n_rows, n_cols, images_so_far)
                ax.axis('off')

                # (C, H, W) tensor layout -> (H, W, C) expected by imshow.
                img = image.transpose((1, 2, 0))
                img = np.clip(std * img + mean, 0, 1)
                ax.imshow(img)

                true_label = short_name(class_names[true_id])
                pred_label = short_name(class_names[pred_id])

                color = 'green' if pred_id == true_id else 'red'
                ax.set_title(f"True: {true_label}\nPred: {pred_label}", color=color, fontsize=8)

                if images_so_far == num_images:
                    done = True
                    break

            if done:
                break

    # Display the figure even when the loader held fewer images than requested.
    plt.show()

# Display ten validation images with their true and predicted labels.
visualize_predictions(model, val_loader, num_images=10)