**This notebook will go over:**

- The Machine Learning Pipeline
  - 🟦Importing
  - 🟧Data Preprocessing
  - 🟩GAN Building
  - 🟥Training
  - 🟪Testing


---

<br>

**Guide to completing this project:**
- Codes (1, 2, 3, 4, 5) indicate where you are in the notebook / ML pipeline
- Sections labeled with (⌛) may take longer to code

Good luck!
Start by scrolling down.

Let's start by importing all our dependencies. Install any that are not already in your environment.

In [1]:
import os
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from torchinfo import summary
import matplotlib
import matplotlib.pyplot as plt

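We use `transforms.Compose()` to chain image transforms. Here it holds a single step, `transforms.ToTensor()`, which converts each PIL image into a float tensor of shape (C, H, W) with pixel values scaled from [0, 255] to [0, 1].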

In [2]:
img_transform = transforms.Compose([
    transforms.ToTensor(),    
])

Create `train_data` and `test_data` by downloading the CIFAR-10 dataset to the root `"./data/train"` or `"./data/test"`, respectively.

NOTE: When writing this out, set the `transform` parameter to our `img_transform`.

In [None]:
#CODE HERE

Create a `DataLoader` named `train_loader` with a batch size of 32.

In [None]:
#CODE HERE

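Use `iter()` and `next()` to get the first batch from `train_loader`, then print its shape. Each batch is an `(images, labels)` pair, so print the shape of the image tensor (and of the labels, if you like).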

In [None]:
#CODE HERE

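The code below displays the first image in the batch. `plt.imshow` expects channels last (H, W, C), so `permute(1, 2, 0)` reorders the (C, H, W) tensor.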

In [None]:
image, label = next(iter(train_loader))
plt.imshow(image[0].permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
plt.title(train_data.classes[label[0].item()])
plt.show()

#PROCEED TO THIS STEP AFTER COMPLETING 3.gan_model.py
We have already trained a model and pushed it to the Hugging Face Hub for you.
Below, you will download this model and run inference with it locally.

In [None]:
from huggingface_hub import snapshot_download
import torch
from gan_model import ConditionalGenerator

# Download the entire repository
local_dir = snapshot_download(
    repo_id="sohumgautam/cifar_gan_model",
    local_dir="./downloaded_model"
)

# Pick the device, then load the generator weights from the downloaded files
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = ConditionalGenerator(latent_dim=100, num_classes=10)
generator.load_state_dict(torch.load(f"{local_dir}/generator.pth", map_location=device))
generator.to(device)
generator.eval()  # inference mode: disables training-only behavior such as dropout

In [None]:
import os
import torch
import matplotlib.pyplot as plt
from gan_model import ConditionalGenerator

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create generator
generator = ConditionalGenerator(latent_dim=100, num_classes=10)

# Load your own checkpoint if it exists; otherwise fall back to the model downloaded from Hugging Face
checkpoint_path = 'checkpoints/cgan_generator_epoch100.pth'  # Change to your path
if not os.path.exists(checkpoint_path):
    checkpoint_path = './downloaded_model/generator.pth'
generator.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=False))

generator.to(device)
generator.eval()

# CIFAR-10 class names
cifar_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                 'dog', 'frog', 'horse', 'ship', 'truck']

# Choose a class (0-9)
class_idx = 3  # 3 = cat (change to any number 0-9)
print(f"Generating a {cifar_classes[class_idx]}")

# Generate one image
with torch.no_grad():
    # Create random noise
    z = torch.randn(1, 100, device=device)
    
    # Create class label
    label = torch.tensor([class_idx], device=device)
    
    # Generate image
    fake_image = generator(z, label)
    
    # Convert from [-1,1] to [0,1] range
    fake_image = fake_image * 0.5 + 0.5

# Display image
plt.figure(figsize=(3, 3))
plt.imshow(fake_image[0].cpu().permute(1, 2, 0))
plt.title(f"Generated {cifar_classes[class_idx]}")
plt.axis('off')
plt.show()