# Run training code in Colab

This notebook recreates the `neural_network_analysis` package, downloads the Tiny-ImageNet dataset, and runs the modular training runner.

Run cells top-to-bottom in Colab.

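The notebook mounts Google Drive, and the training command saves checkpoints under `/content/drive/MyDrive/`. If you don't need persistent storage, skip the mount and change `--model-dir` to a local path such as `/content/saved_model`.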

In [None]:
# Install necessary packages
# Colab already provides a matching torch; we install utilities and albumentations
!pip install -q albumentations==1.3.0 albumentations-pytorch torchsummary huggingface_hub tqdm

In [None]:
# Mount Google Drive so checkpoints written under /content/drive/MyDrive persist
# across Colab sessions. Optional: if you skip it, point --model-dir at a local path.
from google.colab import drive
drive.mount('/content/drive')

# Download and extract Tiny-ImageNet so it lives directly under
# /content/tiny-imagenet-200/{train,val,test}. Safe to re-run: finished steps are skipped.
import os
import subprocess

DATA_ROOT = '/content/tiny-imagenet-200'
ZIP_PATH = '/content/tiny-imagenet-200.zip'  # approx 250MB download
ZIP_URL = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'

if not os.path.isdir(os.path.join(DATA_ROOT, 'train')):
    if not os.path.exists(ZIP_PATH):
        # Download to a temporary name so an interrupted download is never mistaken
        # for a complete archive on the next run.
        partial_path = ZIP_PATH + '.part'
        subprocess.run(['wget', '-q', ZIP_URL, '-O', partial_path], check=True)
        os.replace(partial_path, ZIP_PATH)
    # The archive already contains a top-level tiny-imagenet-200/ folder, so extract
    # into /content; extracting into DATA_ROOT would nest the data one level too deep.
    subprocess.run(['unzip', '-q', '-o', ZIP_PATH, '-d', '/content'], check=True)

# Tiny-ImageNet ships val/ as a flat images/ folder plus val_annotations.txt
# (tab-separated: filename, wnid, bbox...). torchvision's ImageFolder needs one
# sub-folder per class, so move every validation image into val/<wnid>/ once.
val_dir = os.path.join(DATA_ROOT, 'val')
val_images_dir = os.path.join(val_dir, 'images')
if os.path.isdir(val_images_dir):
    with open(os.path.join(val_dir, 'val_annotations.txt'), encoding='utf-8') as annotations:
        for line in annotations:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                continue
            fname, wnid = fields[0], fields[1]
            class_dir = os.path.join(val_dir, wnid)
            os.makedirs(class_dir, exist_ok=True)
            src = os.path.join(val_images_dir, fname)
            if os.path.exists(src):
                os.replace(src, os.path.join(class_dir, fname))
    if not os.listdir(val_images_dir):
        os.rmdir(val_images_dir)

# Sanity check: train and val must expose the same 200 class folders.
train_dir = os.path.join(DATA_ROOT, 'train')
train_classes = sorted(d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d)))
val_classes = sorted(d for d in os.listdir(val_dir) if os.path.isdir(os.path.join(val_dir, d)))
assert len(train_classes) == 200 and train_classes == val_classes, 'unexpected Tiny-ImageNet layout'

## Create the package files (data.py, model.py, trainer.py, utils.py, train.py) in Colab
The notebook will write the same refactored modules into `/content/neural_network_analysis/` so the runner can import them.

In [None]:
import os
import textwrap

# Target directory for the generated package; /content is Colab's working dir,
# so `python -m neural_network_analysis.train` can import it from there.
package_dir = '/content/neural_network_analysis'
os.makedirs(package_dir, exist_ok=True)

# A minimal __init__.py makes the directory importable as a regular package.
init_path = os.path.join(package_dir, '__init__.py')
with open(init_path, 'w', encoding='utf-8') as init_file:
    init_file.write('# Package init for neural_network_analysis\n')

# Module sources for the package, keyed by file name.
# The outer literals use triple *single* quotes so the modules can contain
# triple-double-quoted docstrings without terminating the enclosing string
# (nesting \"\"\" inside \"\"\" is a SyntaxError). textwrap.dedent strips the common
# 8-space indentation so every module starts at column 0.
files = {
    'data.py': textwrap.dedent('''
        import os
        import numpy as np
        from torchvision import datasets, transforms
        from torch.utils.data import DataLoader
        from albumentations.pytorch import ToTensorV2
        import albumentations as A


        def _random_resized_crop(img_size):
            """Return an albumentations RandomResizedCrop producing img_size x img_size crops.

            albumentations < 1.4 (e.g. the pinned 1.3.0) only accepts height/width and
            raises TypeError for size=; newer releases accept (and later require) size=(h, w).
            """
            try:
                return A.RandomResizedCrop(size=(img_size, img_size), p=1.0)
            except TypeError:
                return A.RandomResizedCrop(height=img_size, width=img_size, p=1.0)


        def get_transforms(img_size=64):
            """Return (train_transform, test_transform), both taking a PIL image.

            Both normalize with ImageNet mean/std and produce a CHW float tensor.
            """
            train_transforms = A.Compose([
                _random_resized_crop(img_size),
                A.HorizontalFlip(p=0.5),
                A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ToTensorV2(),
            ])

            def transform_wrapper(pil_img):
                # albumentations operates on numpy arrays; ImageFolder yields PIL images.
                image_np = np.array(pil_img)
                augmented = train_transforms(image=image_np)
                return augmented['image']

            test_transforms = transforms.Compose([
                transforms.Resize(img_size),
                transforms.CenterCrop(img_size),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ])

            return transform_wrapper, test_transforms


        def get_dataloaders(train_dir, val_dir, batch_size=64, img_size=64, num_workers=0, pin_memory=False):
            """Build ImageFolder datasets and DataLoaders.

            Returns (train_loader, val_loader, train_ds). Only the training loader is
            shuffled. Raises ValueError if train_dir and val_dir do not define the same
            class folders (e.g. Tiny-ImageNet val/ not yet split into per-class folders),
            since labels would otherwise be silently wrong.
            """
            transform_train, transform_test = get_transforms(img_size=img_size)

            train_ds = datasets.ImageFolder(train_dir, transform=transform_train)
            val_ds = datasets.ImageFolder(val_dir, transform=transform_test)

            if train_ds.class_to_idx != val_ds.class_to_idx:
                raise ValueError(
                    f'Class folders differ between {train_dir} ({len(train_ds.classes)} classes) '
                    f'and {val_dir} ({len(val_ds.classes)} classes). For Tiny-ImageNet, move '
                    'val/images/* into val/<wnid>/ using val/val_annotations.txt first.'
                )

            loader_args = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
            train_loader = DataLoader(train_ds, shuffle=True, **loader_args)
            # Validation order should not depend on RNG state.
            val_loader = DataLoader(val_ds, shuffle=False, **loader_args)

            return train_loader, val_loader, train_ds
    '''),

    'model.py': textwrap.dedent('''
        import torch.nn as nn
        from torchvision import models


        def build_resnet(model_name='resnet18', num_classes=200, pretrained=False):
            """Return a torchvision ResNet whose final fc layer outputs num_classes logits.

            pretrained=True loads the IMAGENET1K_V1 weights before fc is replaced.
            Raises ValueError for an unsupported model_name.
            """
            model_map = {
                'resnet18': models.resnet18,
                'resnet34': models.resnet34,
                'resnet50': models.resnet50,
            }

            if model_name not in model_map:
                raise ValueError(f"Unsupported model_name: {model_name}")

            model = model_map[model_name](weights='IMAGENET1K_V1' if pretrained else None)
            in_features = model.fc.in_features
            model.fc = nn.Linear(in_features, num_classes)
            return model
    '''),

    'trainer.py': textwrap.dedent('''
        import torch
        import numpy as np
        import torch.nn.functional as F
        from tqdm import tqdm
        import os
        from huggingface_hub import HfApi


        def train_epoch(model, device, train_loader, optimizer, scheduler=None, cutmix_prob=0.0):
            """Run one training epoch.

            scheduler, if given, is stepped after every batch.
            cutmix_prob is accepted for interface compatibility but is currently unused.

            Returns:
                epoch_loss (float), epoch_acc (float), batch_losses (list), batch_accs (list)
            """
            model.train()
            correct = 0
            processed = 0
            batch_losses = []
            batch_accs = []

            pbar = tqdm(train_loader, desc='Train', leave=True)
            for batch_idx, (data, target) in enumerate(pbar):
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                outputs = model(data)
                loss = F.cross_entropy(outputs, target)
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

                pred = outputs.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                processed += data.size(0)
                loss_value = loss.item()
                batch_losses.append(loss_value)

                running_acc = 100. * correct / processed if processed > 0 else 0.0
                batch_accs.append(running_acc)

                # Show loss, batch id, and running accuracy.
                pbar.set_description(desc=f'Loss={loss_value:.4f} Batch_id={batch_idx} Accuracy={running_acc:0.2f}')

            epoch_loss = float(np.mean(batch_losses)) if batch_losses else 0.0
            epoch_acc = 100. * correct / processed if processed > 0 else 0.0
            return epoch_loss, epoch_acc, batch_losses, batch_accs


        def evaluate(model, device, loader):
            """Evaluate model on loader.

            Returns (test_loss, acc, batch_losses, batch_accs): the mean per-sample loss
            and accuracy (%) over all evaluated samples, plus per-batch mean loss and
            per-batch accuracy. An empty loader yields 0.0 for both aggregates.
            """
            model.eval()
            test_loss = 0.0
            correct = 0
            total = 0
            batch_losses = []
            batch_accs = []
            with torch.no_grad():
                pbar = tqdm(loader, desc='Eval', leave=True)
                for batch_idx, (data, target) in enumerate(pbar):
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    loss = F.cross_entropy(output, target, reduction='sum').item()
                    test_loss += loss
                    pred = output.argmax(dim=1)
                    batch_correct = pred.eq(target).sum().item()
                    correct += batch_correct
                    batch_size = data.size(0)
                    total += batch_size
                    batch_losses.append(loss / batch_size)
                    batch_accs.append(100. * batch_correct / batch_size)
                    pbar.set_description(desc=f'Eval Batch_id={batch_idx} Acc={100.*correct/total:0.2f}')

            # Normalize by the samples actually seen (equal to len(loader.dataset) for a
            # plain loader) and avoid ZeroDivisionError on an empty loader.
            if total == 0:
                return 0.0, 0.0, batch_losses, batch_accs
            return test_loss / total, 100. * correct / total, batch_losses, batch_accs


        def save_checkpoint_if_best(model, optimizer, epoch, test_acc, best_acc, model_dir, hf_username=None, model_name=None):
            """Save model_dir/pytorch_model.pt when test_acc beats best_acc; return the new best.

            If hf_username and model_name are given and the new best accuracy exceeds
            50%, the checkpoint is also uploaded to the Hugging Face repo
            hf_username/model_name. Upload failures are reported but do not stop training.
            """
            os.makedirs(model_dir, exist_ok=True)
            if test_acc <= best_acc:
                return best_acc

            best_acc = test_acc
            filename = os.path.join(model_dir, 'pytorch_model.pt')
            torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'epoch': epoch, 'test_acc': test_acc}, filename)

            if hf_username and model_name and best_acc > 50.0:
                repo_id = f"{hf_username}/{model_name}"
                api = HfApi()
                try:
                    api.upload_file(path_or_fileobj=filename, path_in_repo=os.path.basename(filename), repo_id=repo_id, repo_type='model')
                except Exception as exc:
                    # Best-effort upload: keep training, but do not hide the failure.
                    print(f'Hugging Face upload to {repo_id} failed: {exc}')

            return best_acc
    '''),

    'utils.py': textwrap.dedent('''
        import torch
        import matplotlib.pyplot as plt
        import numpy as np

        # Must match the Normalize() statistics used in data.py.
        IMAGENET_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_STD = (0.229, 0.224, 0.225)


        def get_device():
            """Return the CUDA device if available (enabling cuDNN autotuning), else CPU."""
            use_cuda = torch.cuda.is_available()
            device = torch.device('cuda' if use_cuda else 'cpu')
            if use_cuda:
                torch.backends.cudnn.benchmark = True
            return device


        def imshow_batch(img_batch, mean=None, std=None):
            """Un-normalize a CHW image tensor (e.g. a make_grid output) and display it.

            mean/std default to the ImageNet statistics used by data.get_transforms;
            3-element values are broadcast per channel, scalars apply to all channels.
            """
            npimg = img_batch.detach().cpu().numpy()
            mean = np.asarray(IMAGENET_MEAN if mean is None else mean, dtype=float)
            std = np.asarray(IMAGENET_STD if std is None else std, dtype=float)
            if mean.size == 3:
                mean = mean.reshape(3, 1, 1)
            if std.size == 3:
                std = std.reshape(3, 1, 1)
            # Clip so float rounding does not push pixels outside imshow's [0, 1] range.
            npimg = np.clip(npimg * std + mean, 0.0, 1.0)
            plt.imshow(np.transpose(npimg, (1, 2, 0)))
            plt.axis('off')
            plt.show()
    '''),

    'train.py': textwrap.dedent('''
        """Modular training runner for tiny-imagenet using neural_network_analysis modules.

        This script is intentionally minimal - the heavy lifting lives in the
        `neural_network_analysis` package (data.py, model.py, trainer.py, utils.py).
        """
        import os
        import argparse
        import torch.optim as optim

        from neural_network_analysis import data, model as model_mod, trainer, utils


        def parse_args():
            p = argparse.ArgumentParser()
            p.add_argument('--train-dir', default='./tiny-imagenet-200/train')
            p.add_argument('--val-dir', default='./tiny-imagenet-200/val')
            p.add_argument('--model-dir', default='./saved_model')
            p.add_argument('--batch-size', type=int, default=128)
            p.add_argument('--img-size', type=int, default=64)
            p.add_argument('--num-workers', type=int, default=0)
            p.add_argument('--epochs', type=int, default=30)
            p.add_argument('--num-classes', type=int, default=200)
            return p.parse_args()


        def main():
            args = parse_args()
            train_dir = os.path.abspath(args.train_dir)
            val_dir = os.path.abspath(args.val_dir)
            model_dir = os.path.abspath(args.model_dir)
            os.makedirs(model_dir, exist_ok=True)

            device = utils.get_device()
            pin_memory = (device.type == 'cuda')

            train_loader, val_loader, train_ds = data.get_dataloaders(
                train_dir, val_dir, batch_size=args.batch_size, img_size=args.img_size, num_workers=args.num_workers, pin_memory=pin_memory
            )

            # Visualize a small sample from the training loader using the helper in utils
            try:
                from torchvision import utils as tv_utils
                dataiter = iter(train_loader)
                imgs, labels = next(dataiter)
                grid = tv_utils.make_grid(imgs[:4])
                utils.imshow_batch(grid)
            except Exception as e:
                # Don't crash if visualization backend is not available (e.g., headless)
                print('Batch visualization skipped:', e)

            model = model_mod.build_resnet('resnet18', num_classes=args.num_classes, pretrained=False)
            model = model.to(device)

            optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
            scheduler = None

            best_acc = 0.0
            # Metrics history (to match original script)
            train_losses = []
            test_losses = []
            train_acc = []
            test_acc = []
            train_losses_epoch = []

            for epoch in range(args.epochs):
                print(f'EPOCH: {epoch}')
                t_loss, t_acc, batch_losses, batch_accs = trainer.train_epoch(model, device, train_loader, optimizer, scheduler=scheduler)
                v_loss, v_acc, v_batch_losses, v_batch_accs = trainer.evaluate(model, device, val_loader)

                # Append per-batch data to global lists (original code appended per-batch for some lists)
                train_losses.extend(batch_losses)
                train_acc.extend(batch_accs)
                test_losses.extend(v_batch_losses)
                test_acc.extend(v_batch_accs)

                # Epoch level
                train_losses_epoch.append(t_loss)

                print(f'Epoch {epoch}: Train loss {t_loss:.4f}, Train acc {t_acc:.2f}%, Val loss {v_loss:.4f}, Val acc {v_acc:.2f}%')

                best_acc = trainer.save_checkpoint_if_best(model, optimizer, epoch, v_acc, best_acc, model_dir)


        if __name__ == '__main__':
            main()
    '''),
}

# Materialize every module source inside the package directory.
for module_name, module_src in files.items():
    module_path = os.path.join(package_dir, module_name)
    with open(module_path, 'w', encoding='utf-8') as fh:
        fh.write(module_src)

print('Wrote neural_network_analysis package to', package_dir)


## Run the training runner (quick smoke test)
The command below runs the modular runner against the Tiny-ImageNet data extracted under `/content/tiny-imagenet-200`. It is a smoke test, but it may still take a while depending on hardware.
It uses `--batch-size 64 --epochs 1` for a quick run; increase `--epochs` (and adjust `--batch-size`) for real training.

In [None]:
# Run the modular training script against Tiny-ImageNet (quick run)
!python -m neural_network_analysis.train --train-dir /content/tiny-imagenet-200/train --val-dir /content/tiny-imagenet-200/val --batch-size 64 --epochs 1 --model-dir /content/drive/MyDrive/saved_model

Notes:
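- The command above saves the best checkpoint to `/content/drive/MyDrive/saved_model`, which requires Google Drive to be mounted. To keep checkpoints local instead, skip the mount and pass e.g. `--model-dir /content/saved_model`.
- If Colab runs out of RAM, lower `--batch-size` (and `--num-workers`, if you raised it from the default of 0).
- Tiny-ImageNet's `val/` folder ships as a flat `images/` directory. It must be split into one sub-folder per class (using `val/val_annotations.txt`) before `ImageFolder` can read correct labels from it.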

Possible follow-up: pin every package version in a small `requirements.txt` cell for fully reproducible installs.