# Model Training Debug Notebook

This notebook exercises the full training pipeline end to end:
1.  Imports the `LeNet` model from `utils/model.py`.
2.  Imports the `create_dataloader` function from `utils/data_loader.py`.
3.  Imports the `train` function from `utils/training.py`.
4.  Sets up data paths for one client's training shard and a validation set.
5.  Instantiates the model and data loaders.
6.  Calls the `train` function to run a short training loop.
7.  Runs one validation batch through the model to check output shapes.
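The import steps depend on the local `utils` package, so a quick preflight can report which modules fail to resolve before the import cell raises. This is a minimal sketch built on the standard library's `importlib.util.find_spec`; the helper name `missing_modules` is hypothetical and not part of this project.

```python
import importlib.util


def missing_modules(names):
    """Return the dotted module names that cannot be located (hypothetical helper)."""
    missing = []
    for name in names:
        try:
            # find_spec imports parent packages (e.g. `utils`) but not the module itself
            if importlib.util.find_spec(name) is None:
                missing.append(name)
        except ImportError:
            # Raised when a parent package is missing or is not a package
            missing.append(name)
    return missing


print(missing_modules(["utils.model", "utils.data_loader", "utils.training"]))
```

An empty list means all three modules can be found; it does not guarantee they import cleanly, since `find_spec` does not execute the modules themselves.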

In [1]:
import torch
import os
from utils.model import LeNet
from utils.data_loader import create_dataloader
from utils.training import train

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

PyTorch version: 2.5.1+cu124
CUDA available: False


## 1. Define Paths and Hyperparameters

In [2]:
# --- Configuration ---

# Point this to your main sharded data directory
BASE_DIR = "./sharded_data"

# We will train on client 0's data
TRAIN_DATA_DIR = os.path.join(BASE_DIR, "client_0")

# We will validate on the 'test1' set (formerly named 'val1')
VAL_DATA_DIR = os.path.join(BASE_DIR, "test1")

# Hyperparameters
EPOCHS = 5
LEARNING_RATE = 0.001
BATCH_SIZE = 32

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Training data: {TRAIN_DATA_DIR}")
print(f"Validation data: {VAL_DATA_DIR}")
print(f"Running on device: {device}")

Training data: ./sharded_data/client_0
Validation data: ./sharded_data/test1
Running on device: cpu
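`create_dataloader` reports failure only after it is called, so a missing shard directory surfaces late. A minimal sketch of a fail-fast check on the configured paths, assuming the shards are plain directories on disk; `missing_dirs` is a hypothetical helper, and the path values repeat the configuration cell so the snippet runs on its own.

```python
import os


def missing_dirs(*paths):
    """Return the given paths that are not existing directories (hypothetical helper)."""
    return [p for p in paths if not os.path.isdir(p)]


# Same values as the configuration cell, repeated so this snippet runs on its own
BASE_DIR = "./sharded_data"
TRAIN_DATA_DIR = os.path.join(BASE_DIR, "client_0")
VAL_DATA_DIR = os.path.join(BASE_DIR, "test1")

missing = missing_dirs(TRAIN_DATA_DIR, VAL_DATA_DIR)
if missing:
    print(f"Missing data directories: {missing}")
else:
    print("All data directories found.")
```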


## 2. Create DataLoaders

In [3]:
print("Creating training dataloader...")
train_loader = create_dataloader(
    data_dir=TRAIN_DATA_DIR,
    batch_size=BATCH_SIZE,
    shuffle=True
)

print("Creating validation dataloader...")
val_loader = create_dataloader(
    data_dir=VAL_DATA_DIR,
    batch_size=BATCH_SIZE,
    shuffle=False
)

if train_loader and val_loader:
    print(f"Loaded {len(train_loader.dataset)} training samples.")
    print(f"Loaded {len(val_loader.dataset)} validation samples.")
else:
    print("Error creating dataloaders. Check paths and data.")

Creating training dataloader...
Creating validation dataloader...
Loaded 1210 training samples.
Loaded 2000 validation samples.
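These sample counts also explain the progress-bar totals in the training cell. A quick arithmetic check, assuming `create_dataloader` wraps a map-style dataset in a standard `torch.utils.data.DataLoader` with the default batch sampler (not confirmed, since `data_loader.py` isn't shown): 1210 training samples at `BATCH_SIZE = 32` give 38 batches, or 37 if the final partial batch is dropped. The training bars show `0/37` and the validation bar shows `0/63`, which is consistent with `drop_last=True` on the training loader and `drop_last=False` on the validation loader. The helper names below are hypothetical.

```python
def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Batches per epoch for a map-style dataset with the default batch sampler."""
    if drop_last:
        return num_samples // batch_size
    # Integer ceiling division: a final partial batch still counts
    return (num_samples + batch_size - 1) // batch_size


def last_batch_size(num_samples, batch_size):
    """Size of the final batch when drop_last=False."""
    return num_samples % batch_size or batch_size


print(batches_per_epoch(1210, 32))                  # 38
print(batches_per_epoch(1210, 32, drop_last=True))  # 37
print(batches_per_epoch(2000, 32))                  # 63
print(last_batch_size(2000, 32))                    # 16
```

If the inference holds, 26 shuffled training samples are skipped each epoch, and the final validation batch holds 16 samples rather than 32.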


## 3. Initialize Model

In [4]:
model = LeNet()
print(f"Model initialized: {model.__class__.__name__}")
print(f"Total parameters: {model.count_params()}")

Model initialized: LeNet
Total parameters: 2086664
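`count_params()` is a method of this project's `LeNet`, not of `torch.nn.Module`; the standard PyTorch idiom for the same total is `sum(p.numel() for p in model.parameters())`. To get a sense of scale, the sketch below counts parameters by hand for the classic LeNet-5 layout with a single-channel 32×32 input. That layout is an assumption for comparison only, not read from `utils/model.py`, and the helper names are hypothetical.

```python
def conv2d_params(in_ch, out_ch, kernel, bias=True):
    """Weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + (out_ch if bias else 0)


def linear_params(in_features, out_features, bias=True):
    """Weight matrix plus one bias per output feature."""
    return in_features * out_features + (out_features if bias else 0)


# Assumed classic LeNet-5: 1x32x32 -> conv 6@5x5 -> pool -> conv 16@5x5 -> pool
# -> flatten 16*5*5 = 400 -> 120 -> 84 -> 10 classes
lenet5 = [
    conv2d_params(1, 6, 5),
    conv2d_params(6, 16, 5),
    linear_params(16 * 5 * 5, 120),
    linear_params(120, 84),
    linear_params(84, 10),
]
print(sum(lenet5))  # 61706
```

The reported 2,086,664 is far larger than this, so `utils.model.LeNet` is presumably a wider variant or one built for larger inputs; the model file would confirm which.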


## 4. Run Training

In [5]:
if train_loader and val_loader:
    train(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
        device=device,
        val_frequency=3
    )
else:
    print("Skipping training because dataloaders were not loaded.")

Epoch 1/5 [Train]:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 2/5 [Train]:   0%|          | 0/37 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x76daf80b32e0>
Traceback (most recent call last):
  File "/home/rasmus/miniconda3/envs/hackathon/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/home/rasmus/miniconda3/envs/hackathon/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File 

Epoch 3/5 [Train]:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 3/5 [Val]:   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 4/5 [Train]:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 5/5 [Train]:   0%|          | 0/37 [00:00<?, ?it/s]
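In the training output, a `[Val]` bar appears only for epoch 3. That matches a rule like "validate when `epoch % val_frequency == 0`", assuming `train` counts epochs from 1. The real logic lives in `training.py`, which isn't shown, so the helper below is a hypothetical reconstruction of the observed schedule, not the project's code.

```python
def validation_epochs(epochs, val_frequency):
    """1-based epochs on which validation would run under a modulo rule (hypothetical)."""
    return [e for e in range(1, epochs + 1) if e % val_frequency == 0]


print(validation_epochs(5, 3))  # [3]
```

Under this rule the final epoch is not validated unless `epochs` is a multiple of `val_frequency`, which is why epoch 5 shows no `[Val]` bar.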

## 5. Simple Verification

Let's grab one batch from the validation loader and check the output shape.

In [None]:
if val_loader:
    model.eval() # Set model to evaluation mode
    model.to(device)

    # Get one batch of validation data
    data, labels = next(iter(val_loader))
    data, labels = data.to(device), labels.to(device)

    with torch.no_grad():
        outputs = model(data)
    
    print(f"Input data shape: {data.shape}")
    print(f"Output logits shape: {outputs.shape}")
    print(f"(Batch size should match output's first dimension: {BATCH_SIZE})")
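    # fc1 is normally LeNet's first fully connected layer, not its output layer,
    # so read the class count from the last nn.Linear module instead
    last_linear = [m for m in model.modules() if isinstance(m, torch.nn.Linear)][-1]
    print(f"(Output's second dimension should match the number of classes: {last_linear.out_features})")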
else:
    print("Cannot verify, val_loader not loaded.")