<a href="https://colab.research.google.com/github/Shaurya-S0603/ArchAI/blob/main/ArchAI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **TRAINING (GPU)**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import os
import gc
from torch.utils.data import DataLoader
from torchvision.datasets import DatasetFolder
from PIL import Image
import torch.nn.functional as F

# Kernel-wide setup for the GPU training cell.
# IPython magic (not plain Python): sets the CUDA caching-allocator option for this
# kernel. NOTE(review): presumably only honoured if set before CUDA is first
# initialised in the process — confirm when re-running cells out of order.
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Let cuDNN autotune convolution algorithms; image size is fixed by the Resize
# transform below, so the tuning cost is paid once.
torch.backends.cudnn.benchmark = True

# Falls back to CPU when no GPU is attached to the runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Using device: {device}")

BATCH_SIZE = 8          # images per DataLoader batch
IMG_SIZE = 64           # square training resolution; the models use IMG_SIZE // 4
LATENT_DIM = 100        # length of the noise vector fed to the Generator
EPOCHS = 100            # default number of passes over the dataset
LEARNING_RATE = 0.0002  # shared by both Adam optimizers
ACCUM_STEPS = 2         # batches per optimizer step (gradient accumulation)

# Resize to IMG_SIZE x IMG_SIZE, convert to a [0, 1] tensor, then map to [-1, 1]
# ((x - 0.5) / 0.5) so real images share the range of the Generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

class ImageDataset(torch.utils.data.Dataset):
    """Flat folder of images served as ``(transformed_image, 0)`` pairs.

    Only entries directly inside ``root`` whose name ends in ``.png``, ``.jpg`` or
    ``.jpeg`` (case-insensitive) are used. The constant ``0`` is a dummy label;
    the training loop discards it.

    Args:
        root: Directory containing the images.
        transform: Callable applied to each RGB ``PIL.Image``.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root, transform):
        # lower(): ".JPG"/".PNG" files must not be skipped silently.
        # sorted(): os.listdir order is arbitrary; sorting makes index -> file stable.
        self.images = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # The context manager closes the file once convert() has decoded the pixels.
        with Image.open(self.images[idx]) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb), 0

dataset_dir = "/content/drive/MyDrive/ArchAI_Dataset/train"
dataset = ImageDataset(dataset_dir, transform)
if len(dataset) == 0:
    # Fail with a clear message instead of the sampler's opaque "num_samples=0" error.
    raise RuntimeError(f"No .png/.jpg/.jpeg images found in {dataset_dir}")
# Pinned host memory speeds up host->GPU copies; it is only useful on CUDA.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=device.type == "cuda",
)

class Generator(nn.Module):
    """Map a batch of latent vectors to RGB images in [-1, 1].

    ``l1`` projects each LATENT_DIM-long noise vector to a 128-channel map of side
    IMG_SIZE // 4. Two 2x upsampling stages bring it back to IMG_SIZE x IMG_SIZE,
    and Tanh bounds the output to [-1, 1].

    The attribute names (``l1``, ``conv_blocks``) and the layer order fix the
    ``state_dict`` keys; changing them breaks loading of saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.init_size = IMG_SIZE // 4
        # One-element Sequential so the weights are stored under "l1.0.*".
        self.l1 = nn.Sequential(nn.Linear(LATENT_DIM, 128 * self.init_size ** 2))

        input_norm = nn.BatchNorm2d(128)
        # NOTE(review): BatchNorm2d's second positional parameter is ``eps``, so these
        # layers run with eps=0.8 (the default is 1e-5). ``momentum`` may have been
        # intended — confirm before changing, as it alters trained-model outputs.
        first_upsample = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2),
        ]
        second_upsample = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2),
        ]
        to_rgb = [
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(input_norm, *first_upsample, *second_upsample, *to_rgb)

    def forward(self, z):
        projected = self.l1(z)
        side = self.init_size
        feature_map = projected.view(projected.shape[0], 128, side, side)
        return self.conv_blocks(feature_map)

class Discriminator(nn.Module):
    """Score RGB images as real vs. fake, returning raw logits of shape (batch, 1).

    There is deliberately no final Sigmoid. This cell trains with
    ``nn.BCEWithLogitsLoss``, which applies the sigmoid internally; a Sigmoid layer
    here would squash the scores twice. Apply ``torch.sigmoid`` to the output when
    probabilities are needed. The removed Sigmoid had no parameters and was the last
    layer, so ``state_dict`` keys are unchanged.
    """

    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            # Two stride-2 convolutions: IMG_SIZE -> IMG_SIZE // 2 -> IMG_SIZE // 4.
            nn.Conv2d(3, 64, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(128 * (IMG_SIZE // 4) * (IMG_SIZE // 4), 1),
        )

    def forward(self, img):
        return self.model(img)

generator = Generator().to(device)
discriminator = Discriminator().to(device)

checkpoint_dir = "/content/drive/MyDrive/ArchAI/models"
os.makedirs(checkpoint_dir, exist_ok=True)

generator_path = f"{checkpoint_dir}/generator.pth"
discriminator_path = f"{checkpoint_dir}/discriminator.pth"


def _load_checkpoint(module, path):
    """Load the state_dict stored at ``path`` into ``module``.

    Checkpoints written from a ``torch.compile``-wrapped model carry an
    ``_orig_mod.`` prefix on every key; it is stripped so they load into a plain
    module. ``weights_only=True`` refuses to unpickle anything but tensors and
    containers.
    """
    state = torch.load(path, map_location=device, weights_only=True)
    prefix = "_orig_mod."
    state = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state.items()}
    module.load_state_dict(state)


if os.path.exists(generator_path) and os.path.exists(discriminator_path):
    _load_checkpoint(generator, generator_path)
    _load_checkpoint(discriminator, discriminator_path)
    print("✅ Loaded latest checkpoint")

# Module.compile() compiles forward in place. Unlike `m = torch.compile(m)` it does
# not wrap the model, so state_dict() keys keep their plain names and saved
# checkpoints load into an uncompiled Generator/Discriminator.
generator.compile()
discriminator.compile()

# Discriminator outputs logits; BCEWithLogitsLoss applies the sigmoid itself and is
# safe under autocast.
criterion = nn.BCEWithLogitsLoss()
optimizer_G = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
# Loss scaling only applies to CUDA mixed precision. When disabled, scale/step/update
# are pass-throughs, so the training loop runs unchanged on CPU.
scaler = torch.amp.GradScaler("cuda", enabled=device.type == "cuda")

def train_gan(epochs=EPOCHS):
    """Adversarially train the module-level ``generator`` and ``discriminator``.

    Uses the global ``dataloader``, ``criterion`` (BCE-with-logits), optimizers and
    AMP ``scaler``. Gradients are accumulated over ``ACCUM_STEPS`` batches per
    optimizer step. Both models are saved to ``generator_path`` /
    ``discriminator_path`` every 10 epochs.

    Args:
        epochs: Number of passes over ``dataloader``.
    """
    use_amp = device.type == "cuda"
    num_batches = len(dataloader)

    def _state_dict(model):
        # A torch.compile() wrapper prefixes every key with "_orig_mod."; save the
        # underlying module's weights so checkpoints load into a plain model.
        return getattr(model, "_orig_mod", model).state_dict()

    optimizer_G.zero_grad(set_to_none=True)
    optimizer_D.zero_grad(set_to_none=True)

    for epoch in range(epochs):
        for i, (real_images, _) in enumerate(dataloader):
            real_images = real_images.to(device, non_blocking=True)
            batch_size = real_images.size(0)
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # ---- Generator loss: push D(G(z)) towards "real" ----
            # Freeze D so this backward pass does not add generator-loss gradients
            # to the discriminator's accumulated gradients.
            discriminator.requires_grad_(False)
            z = torch.randn(batch_size, LATENT_DIM, device=device)
            with torch.amp.autocast("cuda", enabled=use_amp):
                fake_images = generator(z)
                loss_G = criterion(discriminator(fake_images), real_labels)
            scaler.scale(loss_G / ACCUM_STEPS).backward()
            discriminator.requires_grad_(True)

            # ---- Discriminator loss: real images vs. detached fakes ----
            with torch.amp.autocast("cuda", enabled=use_amp):
                loss_real = criterion(discriminator(real_images), real_labels)
                loss_fake = criterion(discriminator(fake_images.detach()), fake_labels)
                loss_D = (loss_real + loss_fake) / 2
            scaler.scale(loss_D / ACCUM_STEPS).backward()

            # Step every ACCUM_STEPS batches, and on the last batch so a partial
            # accumulation does not leak into the next epoch. Grads are zeroed only
            # after a step, otherwise nothing would accumulate.
            if (i + 1) % ACCUM_STEPS == 0 or i + 1 == num_batches:
                scaler.step(optimizer_G)
                scaler.step(optimizer_D)
                # One update per stepping iteration, after every optimizer stepped.
                scaler.update()
                optimizer_G.zero_grad(set_to_none=True)
                optimizer_D.zero_grad(set_to_none=True)

            if i % 500 == 0:
                print(f"Epoch [{epoch}/{epochs}] | Batch [{i}/{num_batches}] | Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")

        if epoch % 10 == 0:
            torch.save(_state_dict(generator), generator_path)
            torch.save(_state_dict(discriminator), discriminator_path)
            print(f"✅ Checkpoint saved at epoch {epoch}")

train_gan(epochs=EPOCHS)

# Unwrap a torch.compile() wrapper (if any) so the saved keys have no "_orig_mod."
# prefix and the weights load into a plain Generator/Discriminator.
torch.save(getattr(generator, "_orig_mod", generator).state_dict(), generator_path)
torch.save(getattr(discriminator, "_orig_mod", discriminator).state_dict(), discriminator_path)
print("✅ Final models saved!")


# **TRAINING (CPU)**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import os
import gc
from torch.utils.data import DataLoader
from torchvision.datasets import DatasetFolder
from PIL import Image
import torch.nn.functional as F

# CPU-only run: no allocator tweaks, mixed precision or gradient accumulation.
device = torch.device("cpu")
print(f"✅ Using device: {device}")

# Same values as the GPU cell, so both cells build identically shaped models.
BATCH_SIZE = 8
IMG_SIZE = 64
LATENT_DIM = 100
EPOCHS = 100
LEARNING_RATE = 2e-4

# Square resize -> [0, 1] tensor -> [-1, 1], the range of the Generator's Tanh.
_resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
_to_tensor = transforms.ToTensor()
_to_signed_range = transforms.Normalize([0.5], [0.5])
transform = transforms.Compose([_resize, _to_tensor, _to_signed_range])

class ImageDataset(torch.utils.data.Dataset):
    """Flat folder of images served as ``(transformed_image, 0)`` pairs.

    Only entries directly inside ``root`` whose name ends in ``.png``, ``.jpg`` or
    ``.jpeg`` (case-insensitive) are used. The constant ``0`` is a dummy label;
    the training loop discards it.

    Args:
        root: Directory containing the images.
        transform: Callable applied to each RGB ``PIL.Image``.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root, transform):
        # lower(): ".JPG"/".PNG" files must not be skipped silently.
        # sorted(): os.listdir order is arbitrary; sorting makes index -> file stable.
        self.images = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # The context manager closes the file once convert() has decoded the pixels.
        with Image.open(self.images[idx]) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb), 0

dataset_dir = "/content/drive/MyDrive/ArchAI_Dataset/train"
dataset = ImageDataset(dataset_dir, transform)
if len(dataset) == 0:
    # Fail with a clear message instead of the sampler's opaque "num_samples=0" error.
    raise RuntimeError(f"No .png/.jpg/.jpeg images found in {dataset_dir}")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

class Generator(nn.Module):
    """Map a batch of latent vectors to RGB images in [-1, 1].

    ``l1`` projects each LATENT_DIM-long noise vector to a 128-channel map of side
    IMG_SIZE // 4. Two 2x upsampling stages bring it back to IMG_SIZE x IMG_SIZE,
    and Tanh bounds the output to [-1, 1].

    The attribute names (``l1``, ``conv_blocks``) and the layer order fix the
    ``state_dict`` keys; changing them breaks loading of saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.init_size = IMG_SIZE // 4
        # One-element Sequential so the weights are stored under "l1.0.*".
        self.l1 = nn.Sequential(nn.Linear(LATENT_DIM, 128 * self.init_size ** 2))

        input_norm = nn.BatchNorm2d(128)
        # NOTE(review): BatchNorm2d's second positional parameter is ``eps``, so these
        # layers run with eps=0.8 (the default is 1e-5). ``momentum`` may have been
        # intended — confirm before changing, as it alters trained-model outputs.
        first_upsample = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2),
        ]
        second_upsample = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2),
        ]
        to_rgb = [
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(input_norm, *first_upsample, *second_upsample, *to_rgb)

    def forward(self, z):
        projected = self.l1(z)
        side = self.init_size
        feature_map = projected.view(projected.shape[0], 128, side, side)
        return self.conv_blocks(feature_map)

class Discriminator(nn.Module):
    """Score RGB images as real vs. fake, returning probabilities of shape (batch, 1).

    Ends in a Sigmoid because this cell trains with ``nn.BCELoss``, which expects
    probabilities. The layer order fixes the ``state_dict`` keys ("model.0.*",
    "model.2.*", ...), so it must stay as is for checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        reduced_side = IMG_SIZE // 4  # spatial size after the two stride-2 convs
        layers = [
            nn.Conv2d(3, 64, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(128 * reduced_side * reduced_side, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        scores = self.model(img)
        return scores

generator = Generator().to(device)
discriminator = Discriminator().to(device)

checkpoint_dir = "/content/drive/MyDrive/ArchAI/models"
os.makedirs(checkpoint_dir, exist_ok=True)

# Separate *_cpu.pth files, distinct from the GPU cell's generator.pth/discriminator.pth.
generator_path = f"{checkpoint_dir}/generator_cpu.pth"
discriminator_path = f"{checkpoint_dir}/discriminator_cpu.pth"

if os.path.exists(generator_path) and os.path.exists(discriminator_path):
    # The files hold plain state_dicts, so weights_only=True refuses to unpickle
    # arbitrary objects from the Drive-synced folder.
    generator.load_state_dict(torch.load(generator_path, map_location=device, weights_only=True))
    discriminator.load_state_dict(torch.load(discriminator_path, map_location=device, weights_only=True))
    print("✅ Loaded latest checkpoint")

# BCELoss (not BCEWithLogitsLoss) because this cell's Discriminator ends in a Sigmoid.
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

# Plain fp32 training: one generator step and one discriminator step per batch,
# checkpointing every 10 epochs and once more at the end.
for epoch in range(EPOCHS):
    for i, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to(device)
        batch_size = real_images.size(0)
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # Generator step: reward fakes that D scores as real. D is frozen so this
        # backward pass skips D's weight gradients; the input gradients G needs
        # still flow through D.
        optimizer_G.zero_grad()
        discriminator.requires_grad_(False)
        z = torch.randn(batch_size, LATENT_DIM, device=device)
        fake_images = generator(z)
        loss_G = criterion(discriminator(fake_images), real_labels)
        loss_G.backward()
        optimizer_G.step()
        discriminator.requires_grad_(True)

        # Discriminator step on real images and detached fakes (no grads into G).
        optimizer_D.zero_grad()
        loss_real = criterion(discriminator(real_images), real_labels)
        loss_fake = criterion(discriminator(fake_images.detach()), fake_labels)
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        optimizer_D.step()

        if i % 500 == 0:
            print(f"Epoch [{epoch}/{EPOCHS}] | Batch [{i}/{len(dataloader)}] | Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")

    if epoch % 10 == 0:
        torch.save(generator.state_dict(), generator_path)
        torch.save(discriminator.state_dict(), discriminator_path)
        print(f"✅ Checkpoint saved at epoch {epoch}")

torch.save(generator.state_dict(), generator_path)
torch.save(discriminator.state_dict(), discriminator_path)
print("✅ Final models saved!")


# **TESTING**

In [None]:
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Using device: {device}")

generator = Generator().to(device)

# Ask until the answer is 1 (CPU-trained weights) or 2 (GPU-trained weights).
while True:
    try:
        load = int(input("Trained on CPU or GPU? (1,2): "))
    except ValueError:
        load = None
    if load in (1, 2):
        break
    print("Please enter 1 (CPU) or 2 (GPU).")

generator_checkpoint = (
    "/content/drive/MyDrive/ArchAI/models/generator_cpu.pth"
    if load == 1
    else "/content/drive/MyDrive/ArchAI/models/generator.pth"
)
# map_location=device lets GPU-trained weights load on a CPU-only runtime.
generator_state = torch.load(generator_checkpoint, map_location=device, weights_only=True)
# Weights saved from a torch.compile()-wrapped model prefix every key with
# "_orig_mod."; strip it so they fit the plain Generator.
_compile_prefix = "_orig_mod."
generator_state = {
    (k[len(_compile_prefix):] if k.startswith(_compile_prefix) else k): v
    for k, v in generator_state.items()
}
generator.load_state_dict(generator_state)
# Inference mode: BatchNorm uses its running statistics.
generator.eval()

def generate_images(num_images=5):
    """Sample ``num_images`` images from ``generator`` and show them in one row.

    Args:
        num_images: How many images to draw; must be at least 1.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    z = torch.randn(num_images, LATENT_DIM, device=device)
    with torch.no_grad():
        fake_images = generator(z).cpu()

    # Tanh output lies in [-1, 1]; ToPILImage expects floats in [0, 1].
    fake_images = ((fake_images + 1) / 2).clamp(0, 1)

    # squeeze=False keeps `axes` 2-D even for one image; by default subplots(1, 1)
    # returns a bare Axes, and indexing it raises TypeError.
    _, axes = plt.subplots(1, num_images, figsize=(15, 15), squeeze=False)
    to_pil = transforms.ToPILImage()
    for ax, image in zip(axes[0], fake_images):
        ax.imshow(to_pil(image))
        ax.axis("off")
    plt.show()
    print()


generate_images()
