In [6]:
# Install dependencies into the kernel's own environment (%pip, unlike !pip,
# targets the interpreter this notebook is running on).
# huggingface_hub is capped below 1.0 because train.py imports the deprecated
# `Repository` helper, which is slated for removal in the 1.0 release.
%pip install -q torch torchvision torchaudio
%pip install -q wandb "huggingface_hub<1.0" gradio




In [7]:
import wandb
from huggingface_hub import notebook_login

# --- Service authentication ----------------------------------------------
# Weights & Biases: prompts for an API key unless one is already stored.
wandb.login()
# Hugging Face Hub: shows a token-entry widget in the notebook.
notebook_login()


  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mraghav051006[0m ([33mraghav051006-sage-university-ind[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [8]:
%%writefile model.py
import torch.nn as nn
import torch.nn.functional as F

# Simplified AmoebaNet cell
class AmoebaCell(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use padding=1 with the default stride of 1, so the
    spatial size of the input is preserved.

    Args:
        in_channels: number of channels in the incoming feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names are the state_dict keys; keep them stable so saved
        # checkpoints remain loadable.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply conv1/bn1/ReLU then conv2/bn2/ReLU and return the result."""
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.relu(norm(conv(x)))
        return x

class AmoebaNetSmall(nn.Module):
    """Small image classifier: conv stem, six AmoebaCells, global pool, linear head.

    Every stage is 64 channels wide and keeps the spatial resolution;
    adaptive average pooling collapses H x W to 1 x 1 before the classifier.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        width = 64
        depth = 6
        # Submodules are created in the same order as before so parameter
        # initialization consumes the RNG identically; names are state_dict keys.
        self.stem = nn.Conv2d(3, width, kernel_size=3, padding=1)
        self.cells = nn.ModuleList([AmoebaCell(width, width) for _ in range(depth)])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(width, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        features = self.stem(x)
        for cell in self.cells:
            features = cell(features)
        pooled = self.pool(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)


Writing model.py


In [9]:
%%writefile train.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import wandb
from model import AmoebaNetSmall
from huggingface_hub import HfApi, HfFolder, Repository
import os

# -----------------------------
# Config
# -----------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 128
EPOCHS = 10
LR = 0.01
SEED = 42
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Seed torch's RNG (all devices) before any model/data setup so weight
# initialization, shuffling and random augmentation are repeatable across
# runs. cuDNN kernels may still introduce small nondeterminism.
torch.manual_seed(SEED)

wandb.init(project="gpipe-amoeba-cifar10", config={
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "learning_rate": LR,
    "seed": SEED,
})

# -----------------------------
# CIFAR-10 Dataset
# -----------------------------
# ToTensor scales pixels to [0, 1]; Normalize with mean/std 0.5 then maps
# each channel to [-1, 1]. Shared by the train and test pipelines.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Random flip/crop augmentation belongs to the training split only.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    normalize,
])

# Evaluation must see the unmodified test images; applying the random
# training augmentation here made test accuracy vary between evaluations
# and not reflect performance on the real test set.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# Backward-compatible alias: `transform` previously named the single pipeline.
transform = train_transform

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE,
                         shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
testloader = DataLoader(testset, batch_size=BATCH_SIZE,
                        shuffle=False, num_workers=2)

# -----------------------------
# Model + Optimizer
# -----------------------------
# Network is moved to DEVICE; trained with SGD + momentum and L2 weight decay.
model = AmoebaNetSmall(num_classes=10)
model = model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=0.9,
    weight_decay=5e-4,
)

# -----------------------------
# Training & Testing
# -----------------------------
def train():
    """Train the module-level `model` on `trainloader` for EPOCHS epochs.

    After each epoch the mean of the per-batch losses is logged to W&B as
    "train_loss" and printed. Every second epoch the model's state_dict is
    saved to CHECKPOINT_DIR/epoch_<n>.pth.
    """
    for epoch_num in range(1, EPOCHS + 1):
        model.train()
        loss_sum = 0
        for batch_inputs, batch_labels in trainloader:
            batch_inputs = batch_inputs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            optimizer.zero_grad()
            batch_loss = criterion(model(batch_inputs), batch_labels)
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.item()

        mean_loss = loss_sum / len(trainloader)
        wandb.log({"train_loss": mean_loss})
        print(f"Epoch {epoch_num}, Loss: {mean_loss:.3f}")

        if epoch_num % 2 == 0:
            ckpt_path = f"{CHECKPOINT_DIR}/epoch_{epoch_num}.pth"
            torch.save(model.state_dict(), ckpt_path)
            print(f"✅ Saved checkpoint at {ckpt_path}")

def test():
    """Evaluate the module-level `model` on `testloader`.

    Computes top-1 accuracy in percent, logs it to W&B as "test_accuracy"
    and prints it. Gradients are disabled and the model is put in eval mode.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in testloader:
            batch_inputs = batch_inputs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)
            predicted = model(batch_inputs).max(1).indices
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()
    acc = 100.0 * n_correct / n_seen
    wandb.log({"test_accuracy": acc})
    print(f"Test Accuracy: {acc:.2f}%")

if __name__ == "__main__":
    train()
    test()
    torch.save(model.state_dict(), "final_model.pth")
    print("✅ Final model saved as final_model.pth")

    # All metrics are logged; close the W&B run before the (network) upload.
    wandb.finish()

    # Push to Hugging Face Hub.
    # Uses the Hub HTTP API instead of a local git clone: the git-based
    # `Repository` helper is deprecated, and the previous os.system git/cp
    # commands ignored failures (e.g. a rejected push) silently. Errors from
    # these calls now raise instead.
    # HfApi() reads the locally cached token (saved by notebook_login()).
    repo_id = "Raghav81/gpipe-amoeba-cifar10"
    api = HfApi()
    # exist_ok=True makes this a no-op when the repo already exists.
    api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
    api.upload_file(
        path_or_fileobj="final_model.pth",
        path_in_repo="final_model.pth",
        repo_id=repo_id,
        repo_type="model",
        commit_message="Add trained model",
    )
    print(f"✅ Uploaded final_model.pth to https://huggingface.co/{repo_id}")

Overwriting train.py


In [None]:
# Run train.py in a separate Python process: trains for EPOCHS epochs,
# evaluates on the test set, saves final_model.pth and pushes it to the Hub.
!python train.py


[34m[1mwandb[0m: Currently logged in as: [33mraghav051006[0m ([33mraghav051006-sage-university-ind[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.21.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/content/wandb/run-20250829_113731-duh78aih[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mgentle-disco-3[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/raghav051006-sage-university-ind/gpipe-amoeba-cifar10[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/raghav051006-sage-university-ind/gpipe-amoeba-cifar10/runs/duh78aih[0m
100% 170M/170M [00:01<00:00, 99.8MB/s]


In [None]:
# NOTE(review): identical to the previous cell. Running both trains a fresh
# model twice and overwrites final_model.pth; keep only one of these cells
# unless a second run is intentional.
!python train.py


In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
from model import AmoebaNetSmall

# Load trained weights (written by train.py) for CPU inference.
model = AmoebaNetSmall(num_classes=10)
# weights_only=True limits unpickling to tensors and plain containers, which
# is all a state_dict needs, instead of executing arbitrary pickled objects.
model.load_state_dict(
    torch.load("final_model.pth", map_location="cpu", weights_only=True)
)
model.eval()

# Display names in torchvision's CIFAR10 label-index order.
classes = ['plane','car','bird','cat','deer','dog','frog','horse','ship','truck']

# Inference preprocessing: resize to the 32x32 training resolution and apply
# the same normalization as training (no augmentation).
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

def predict(img):
    """Classify an uploaded image for display in a gr.Label.

    Args:
        img: PIL image from gr.Image(type="pil"), or None when the user
            submits without uploading anything.

    Returns:
        dict mapping each class name to its softmax probability (values sum
        to 1, as gr.Label expects for confidences), or None when no image
        was given so the label output is left empty.
    """
    if img is None:
        return None
    # Uploaded PNGs may be RGBA or grayscale; the network expects 3 channels.
    batch = transform(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        # The model outputs raw logits; convert them to probabilities.
        probs = torch.softmax(model(batch), dim=1)[0]
    return {name: float(p) for name, p in zip(classes, probs)}

# Image upload in, top-3 class confidences out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
)
demo.launch()
