In [1]:
import helpMe
import wandb
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
from PIL import Image, ImageFilter

from torch.utils.data import DataLoader
import torchvision.datasets as Datasets
import torchvision.transforms as T
import torch.nn.functional as F
import matplotlib.pyplot as plt

import torchvision.transforms.functional as TF

# Select the compute device through the project helper module.
# NOTE(review): presumably returns CUDA when available and CPU otherwise — confirm in helpMe.
# train_gan picks its own device internally and does not read this variable.
device = helpMe.get_default_device()
# device = 'cpu'

## Configurations

In [2]:
# --- Run identity -----------------------------------------------------------
# Used to build the checkpoint directory name further down.
model_name = "CUNet_Lapgan_EMNIST"

# --- Data / training hyper-parameters ---------------------------------------
image_size = 32   # target resolution of the finest scale
channels = 1      # single-channel (grayscale) images
batch_size = 64   # passed to the DataLoader and to train_gan
epochs = 110      # passed to train_gan as num_epochs

# --- Unused settings kept for reference --------------------------------------
# z_dim = 128
# DATA_DIR = './imageNet_lp/torch_image_folder/mnt/volume_sfo3_01/imagenet-lt/ImageDataset/train'
# stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

In [3]:
# Pre-processing applied to every EMNIST sample.
# The rotation range is the single angle -90 and the flip probability is 1,
# so both "random" ops are in fact deterministic.
# NOTE(review): presumably this undoes the transposed orientation EMNIST images
# are stored in — confirm by viewing a few samples.
transforms = T.Compose(
    [
        T.Resize(32),
        T.RandomRotation(degrees=(-90, -90)),
        T.RandomHorizontalFlip(p=1),
        T.ToTensor(),
        # Maps ToTensor's [0, 1] output to [-1, 1].
        T.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Training split of EMNIST "byclass"; downloaded into `root` on first use.
dataset = Datasets.EMNIST(
    root='./Datasxts/EMNIST/',
    split='byclass',
    train=True,
    download=True,
    transform=transforms,
)

In [4]:
dataset

Dataset EMNIST
    Number of datapoints: 697932
    Root location: ./Datasxts/EMNIST/
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=32, interpolation=bilinear, max_size=None, antialias=True)
               RandomRotation(degrees=[-90.0, -90.0], interpolation=nearest, expand=False, fill=0)
               RandomHorizontalFlip(p=1)
               ToTensor()
               Normalize(mean=(0.5,), std=(0.5,))
           )

In [5]:
# Number of label classes reported by the dataset. Used later to size the
# conditional generators/discriminators and to draw one sample per class.
num_classes= len(dataset.classes)

In [6]:
num_classes

62

In [7]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
import wandb  # Import wandb

# Ensure you have all necessary imports for your UNet and Discriminator models
import Unet  # Assuming you have a file named Unet.py with relevant model definitions

# # Initialize gradient scalers for mixed precision training
# scaler_G_32 = GradScaler()
# scaler_G_16 = GradScaler()
# scaler_G_8 = GradScaler()
# scaler_D_32 = GradScaler()
# scaler_D_16 = GradScaler()
# scaler_D_8 = GradScaler()

class LAPGAN(nn.Module):
    """Coarse-to-fine sampler that chains the three conditional generators.

    The coarsest generator produces an image from noise; each following stage
    upsamples the running image by 2x (bilinear) and adds the residual predicted
    by the next generator. Every intermediate result is written to disk through
    ``helpMe.save_generated_images`` and logged to Weights & Biases.

    Args:
        generators: Sequence unpacked as ``(G_32, G_16, G_8)``.
        save_path: Forwarded as ``path`` to ``helpMe.save_generated_images``.
    """

    def __init__(self, generators, save_path):
        super().__init__()
        gen_32, gen_16, gen_8 = generators
        # Registered in the same order as the incoming tuple.
        self.G_32 = gen_32
        self.G_16 = gen_16
        self.G_8 = gen_8
        self.save_path = save_path

    def forward(self, z, y, epoch, i):
        """Generate images for labels ``y`` from noise ``z``.

        ``epoch`` and ``i`` are only used to tag the saved / logged images.
        Returns the output of the final (finest) stage.
        """
        # Stage 1: coarsest image straight from noise.
        out = self.G_8(z, y)
        self._snapshot(out, epoch, i, '8_before')

        # Stage 2: upsample 2x, then add the residual from G_16.
        out = F.interpolate(out, scale_factor=2, mode='bilinear')
        self._snapshot(out, epoch, i, '8_after')
        out = out + self.G_16(out, y)
        self._snapshot(out, epoch, i, '16')

        # Stage 3: upsample 2x again, then add the residual from G_32.
        out = F.interpolate(out, scale_factor=2, mode='bilinear')
        out = out + self.G_32(out, y)
        self._snapshot(out, epoch, i, '32')

        return out

    def _snapshot(self, images, epoch, i, res):
        """Save ``images`` to disk and log them to wandb under the tag ``res``."""
        helpMe.save_generated_images(
            genH_realH=images,
            recon=None,
            epoch=epoch,
            i=i,
            path=self.save_path,
            res=res,
            a='LAPgan/',
        )
        self.log_generated_images(images, epoch, i, res=res)

    def log_generated_images(self, images, epoch, i, res):
        """Log ``images`` to wandb as one grid (8 per row, each image normalized on its own)."""
        grid = torchvision.utils.make_grid(
            images.detach().cpu(),
            normalize=True,
            scale_each=True,
            nrow=8,
        )
        caption = f"Epoch: {epoch}"
        wandb.log({f"Generated Images {res}": [wandb.Image(grid, caption=caption)]})





import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import wandb

def save_model(generators, discriminators, opt_G_32, opt_D_32, opt_G_16, opt_D_16, opt_G_8, opt_D_8, epoch, checkpoint_dir):
    """Write a single training checkpoint to ``<checkpoint_dir>/checkpoint.pth``.

    The checkpoint is first written to a temporary file in the same directory
    and then moved over the final name with ``os.replace``. An interruption
    while saving (e.g. a KeyboardInterrupt in the notebook) therefore cannot
    leave a truncated ``checkpoint.pth`` behind that would break the next resume.

    Args:
        generators: Indexable as ``[G_32, G_16, G_8]``.
        discriminators: Indexable as ``[D_32, D_16, D_8]``.
        opt_G_32, opt_D_32, opt_G_16, opt_D_16, opt_G_8, opt_D_8: The matching optimizers.
        epoch: Epoch number stored under the ``'epoch'`` key.
        checkpoint_dir: Target directory; created if it does not exist.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')
    tmp_path = checkpoint_path + '.tmp'

    state = {
        'epoch': epoch,
        'G_32_state_dict': generators[0].state_dict(),
        'D_32_state_dict': discriminators[0].state_dict(),
        'G_16_state_dict': generators[1].state_dict(),
        'D_16_state_dict': discriminators[1].state_dict(),
        'G_8_state_dict': generators[2].state_dict(),
        'D_8_state_dict': discriminators[2].state_dict(),
        'opt_G_32_state_dict': opt_G_32.state_dict(),
        'opt_D_32_state_dict': opt_D_32.state_dict(),
        'opt_G_16_state_dict': opt_G_16.state_dict(),
        'opt_D_16_state_dict': opt_D_16.state_dict(),
        'opt_G_8_state_dict': opt_G_8.state_dict(),
        'opt_D_8_state_dict': opt_D_8.state_dict(),
    }

    try:
        torch.save(state, tmp_path)
        # Same-directory rename: the previous checkpoint stays intact until this succeeds.
        os.replace(tmp_path, checkpoint_path)
    finally:
        # Only still present when saving or replacing failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def load_model(generators, discriminators, opt_G_32, opt_D_32, opt_G_16, opt_D_16, opt_G_8, opt_D_8, checkpoint_dir, device):
    """Restore models and optimizers from ``<checkpoint_dir>/checkpoint.pth`` if present.

    Args:
        generators: Indexable as ``[G_32, G_16, G_8]``; updated in place.
        discriminators: Indexable as ``[D_32, D_16, D_8]``; updated in place.
        opt_G_32, opt_D_32, opt_G_16, opt_D_16, opt_G_8, opt_D_8: Optimizers, updated in place.
        checkpoint_dir: Directory expected to contain ``checkpoint.pth``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The epoch to start from: the stored epoch + 1 when a checkpoint was
        found, otherwise 1.
    """
    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')
    print(checkpoint_path)

    # Nothing to resume from: leave every object untouched.
    if not os.path.exists(checkpoint_path):
        print("Starting training from scratch.")
        return 1

    checkpoint = torch.load(checkpoint_path, map_location=device)

    # (object to restore, key holding its state) — networks first, then optimizers.
    restore_plan = (
        (generators[0], 'G_32_state_dict'),
        (discriminators[0], 'D_32_state_dict'),
        (generators[1], 'G_16_state_dict'),
        (discriminators[1], 'D_16_state_dict'),
        (generators[2], 'G_8_state_dict'),
        (discriminators[2], 'D_8_state_dict'),
        (opt_G_32, 'opt_G_32_state_dict'),
        (opt_D_32, 'opt_D_32_state_dict'),
        (opt_G_16, 'opt_G_16_state_dict'),
        (opt_D_16, 'opt_D_16_state_dict'),
        (opt_G_8, 'opt_G_8_state_dict'),
        (opt_D_8, 'opt_D_8_state_dict'),
    )
    for target, key in restore_plan:
        target.load_state_dict(checkpoint[key])

    # The stored epoch finished completely, so continue with the next one.
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming training from epoch {start_epoch}.")
    return start_epoch


def _gan_step(G, D, opt_G, opt_D, gen_input, real, labels, criterion):
    """Run one discriminator update followed by one generator update for a single scale.

    Returns ``(fake, loss_D, loss_G)``: the generator output detached from the
    autograd graph, and both losses as Python floats.
    """
    fake = G(gen_input, labels)

    # Discriminator: real -> 1, fake -> 0. `fake` is detached so no gradient reaches G here.
    opt_D.zero_grad()
    pred_real = D(real, labels).view(-1)
    pred_fake = D(fake.detach(), labels).view(-1)
    loss_D = (criterion(pred_real, torch.ones_like(pred_real))
              + criterion(pred_fake, torch.zeros_like(pred_fake))) / 2
    loss_D.backward()
    opt_D.step()

    # Generator: try to make the just-updated discriminator answer "real".
    # This backward also deposits gradients in D's parameters; they are cleared
    # by opt_D.zero_grad() at the start of the next step.
    opt_G.zero_grad()
    pred_fake = D(fake, labels).view(-1)
    loss_G = criterion(pred_fake, torch.ones_like(pred_fake))
    loss_G.backward()
    opt_G.step()

    return fake.detach(), loss_D.item(), loss_G.item()


def train_gan(generators, discriminators, dataloader, num_epochs, batch_size, noise_dim=100, checkpoint_dir=None):
    """Train the three-scale conditional LAPGAN and log progress to Weights & Biases.

    Per batch, three independent adversarial updates are run:

    * coarsest scale: ``G_8`` maps noise to an image; the real target is the
      batch downsampled by 0.25;
    * middle scale: ``G_16`` receives the 2x-upsampled ``G_8`` output; the real
      target is the second value returned by ``helpMe.to_gaus`` for the batch
      downsampled by 0.5;
    * finest scale: ``G_32`` receives the 2x-upsampled middle reconstruction; the
      real target is the second value returned by ``helpMe.to_gaus`` for the
      full-resolution batch.

    NOTE(review): the second return value of ``helpMe.to_gaus`` is treated as the
    high-frequency residual — confirm against helpMe.

    All generator losses are purely adversarial (BCE). At the end of every epoch
    one sample per class is generated through ``LAPGAN`` and, when
    ``checkpoint_dir`` is set, a checkpoint is written.

    Relies on module-level names: ``helpMe``, ``LAPGAN``, ``load_model``,
    ``save_model`` and ``num_classes``. Calls ``wandb.init``; the caller is
    responsible for ``wandb.finish``.

    Args:
        generators: Sequence unpacked as ``(G_32, G_16, G_8)``.
        discriminators: Sequence unpacked as ``(D_32, D_16, D_8)``.
        dataloader: Yields ``(images, labels)`` batches.
        num_epochs: Last epoch number to run (epochs are 1-based).
        batch_size: Recorded in the wandb config. The noise batch is sized from
            the actual batch, so a smaller final batch is handled.
        noise_dim: Length of the noise vector fed to ``G_8``.
        checkpoint_dir: Directory used for resuming, checkpoints and image
            dumps. When falsy, resuming and checkpoint saving are skipped.
            NOTE(review): it is still forwarded as ``path`` to
            ``helpMe.save_generated_images``; behaviour with ``None`` depends on helpMe.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    G_32, G_16, G_8 = [gen.to(device) for gen in generators]
    D_32, D_16, D_8 = [disc.to(device) for disc in discriminators]

    # Optimizers (identical settings for every network).
    adam_kwargs = dict(lr=2e-4, betas=(0.5, 0.999))
    opt_G_32 = optim.Adam(G_32.parameters(), **adam_kwargs)
    opt_D_32 = optim.Adam(D_32.parameters(), **adam_kwargs)
    opt_G_16 = optim.Adam(G_16.parameters(), **adam_kwargs)
    opt_D_16 = optim.Adam(D_16.parameters(), **adam_kwargs)
    opt_G_8 = optim.Adam(G_8.parameters(), **adam_kwargs)
    opt_D_8 = optim.Adam(D_8.parameters(), **adam_kwargs)

    # Load checkpoint if it exists
    if checkpoint_dir:
        start_epoch = load_model(generators, discriminators, opt_G_32, opt_D_32, opt_G_16, opt_D_16, opt_G_8, opt_D_8, checkpoint_dir, device)
    else:
        start_epoch = 1

    criterion = nn.BCELoss()

    # Initialize wandb
    wandb.init(project="LAPGAN_EMNIST", config={
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": 2e-4,
        "betas": (0.5, 0.999)
    })

    # End-of-epoch sampler. Noise and labels are drawn once so that the grids
    # logged for different epochs of this run show the same latent points.
    lapgan_model = LAPGAN(generators, checkpoint_dir).to(device)
    fixed_noise = torch.randn(num_classes, noise_dim, device=device)
    fixed_labels = torch.arange(0, num_classes, dtype=torch.long, device=device)

    for epoch in range(start_epoch, num_epochs + 1):
        running = {'D32': 0.0, 'G32': 0.0, 'D16': 0.0, 'G16': 0.0, 'D8': 0.0, 'G8': 0.0}

        with tqdm(enumerate(dataloader), total=len(dataloader)) as t:
            for i, (images, labels) in t:
                labels = labels.to(device)
                dump_images = i % 100 == 0

                # ---- Coarsest scale: noise -> image ----
                real_8 = F.interpolate(images, scale_factor=0.25, mode='bilinear').to(device)
                # Sized from the real batch rather than `batch_size`, so a short last batch works.
                noise = torch.randn(labels.size(0), noise_dim, device=device)
                fake_8, loss_D_8, loss_G_8 = _gan_step(G_8, D_8, opt_G_8, opt_D_8, noise, real_8, labels, criterion)
                running['D8'] += loss_D_8
                running['G8'] += loss_G_8

                if dump_images:
                    helpMe.save_generated_images(genH_realH=torch.cat([real_8, fake_8], dim=0), recon=None, epoch=epoch, i=i,
                                                 path=checkpoint_dir, res='8', a='')

                # Input for the next stage; already detached from G_8.
                images_G8 = F.interpolate(fake_8, scale_factor=2, mode='bilinear')

                # ---- Middle scale: upsampled coarse image -> residual ----
                images_16 = F.interpolate(images, scale_factor=0.5, mode='bilinear')
                _, real_high_freqs_16 = helpMe.to_gaus(imgs=images_16)
                real_high_freqs_16 = real_high_freqs_16.to(device)
                generated_high_freqs_16, loss_disc_16, loss_gen_16 = _gan_step(
                    G_16, D_16, opt_G_16, opt_D_16, images_G8, real_high_freqs_16, labels, criterion)
                running['D16'] += loss_disc_16
                running['G16'] += loss_gen_16

                recon_16 = images_G8 + generated_high_freqs_16
                if dump_images:
                    helpMe.save_generated_images(genH_realH=torch.cat([real_high_freqs_16, generated_high_freqs_16], dim=0),
                                                 recon=torch.cat([images_G8, recon_16], dim=0), epoch=epoch, i=i,
                                                 path=checkpoint_dir, res='16', a='')

                images_G16 = F.interpolate(recon_16, scale_factor=2, mode='bilinear')

                # ---- Finest scale: upsampled middle reconstruction -> residual ----
                _, real_high_freqs_32 = helpMe.to_gaus(imgs=images)
                real_high_freqs_32 = real_high_freqs_32.to(device)
                generated_high_freqs_32, loss_disc_32, loss_gen_32 = _gan_step(
                    G_32, D_32, opt_G_32, opt_D_32, images_G16, real_high_freqs_32, labels, criterion)
                running['D32'] += loss_disc_32
                running['G32'] += loss_gen_32

                if dump_images:
                    recon_32 = images_G16 + generated_high_freqs_32
                    helpMe.save_generated_images(genH_realH=torch.cat([real_high_freqs_32, generated_high_freqs_32], dim=0),
                                                 recon=torch.cat([images_G16, recon_32], dim=0), epoch=epoch, i=i,
                                                 path=checkpoint_dir, res='32', a='')

                t.set_description(f'Epoch [{epoch}/{num_epochs}]')
                t.set_postfix({
                    'D32_loss': f'{loss_disc_32:.3f}',
                    'G32_loss': f'{loss_gen_32:.3f}',
                    'D16_loss': f'{loss_disc_16:.3f}',
                    'G16_loss': f'{loss_gen_16:.3f}',
                    'D8_loss': f'{loss_D_8:.3f}',
                    'G8_loss': f'{loss_G_8:.3f}'
                })

                wandb.log({
                    'D32_loss': loss_disc_32,
                    'G32_loss': loss_gen_32,
                    'D16_loss': loss_disc_16,
                    'G16_loss': loss_gen_16,
                    'D8_loss': loss_D_8,
                    'G8_loss': loss_G_8
                })

                # NOTE(review): kept from the original; releasing the CUDA cache
                # every iteration has a cost — consider dropping it if memory allows.
                torch.cuda.empty_cache()

        # Every scale is updated exactly once per batch, so all averages share one divisor.
        num_batches = len(dataloader)
        avg_d32_loss = running['D32'] / num_batches
        avg_g32_loss = running['G32'] / num_batches
        avg_d16_loss = running['D16'] / num_batches
        avg_g16_loss = running['G16'] / num_batches
        avg_d8_loss = running['D8'] / num_batches
        avg_g8_loss = running['G8'] / num_batches
        print(f"Epoch {epoch}: Loss D32: {avg_d32_loss:.4f}, Loss G32: {avg_g32_loss:.4f}, Loss D16: {avg_d16_loss:.4f}, Loss G16: {avg_g16_loss:.4f}, Loss D8: {avg_d8_loss:.4f}, Loss G8: {avg_g8_loss:.4f}")

        # One sample per class through the full pyramid; inference only, so no graph is needed.
        with torch.no_grad():
            lapgan_model(fixed_noise, fixed_labels, epoch, i)

        if checkpoint_dir:
            save_model(generators, discriminators, opt_G_32, opt_D_32, opt_G_16, opt_D_16, opt_G_8, opt_D_8, epoch, checkpoint_dir)

# Assuming dataset and other parameters are properly defined elsewhere
# Conditional generator / discriminator pair for each scale of the pyramid.
gen32 = Unet.CUNetGenerator(num_classes=num_classes)
disc32 = Unet.CDiscriminator(num_classes=num_classes)
gen16 = Unet.CUNetGenerator(image_size=16, features=[64, 128, 256], num_classes=num_classes)
disc16 = Unet.CDiscriminator(image_size=16, features=[64, 128, 256], num_classes=num_classes)
gen8 = Unet.Generator_L(num_classes=num_classes)
disc8 = Unet.Discriminator_L(num_classes=num_classes)

# drop_last=True keeps every batch at exactly `batch_size` samples.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, drop_last=True)

# Order matters: train_gan, save_model and load_model all expect (32, 16, 8).
generators = (gen32, gen16, gen8)
discriminators = (disc32, disc16, disc8)

checkpoint_dir = f"Models/{model_name}3/"

try:
    train_gan(generators, discriminators, dataloader, num_epochs=epochs, batch_size=batch_size, checkpoint_dir=checkpoint_dir)
finally:
    # Close the wandb run even when training raises or is interrupted from the notebook.
    wandb.finish()


Models/CUNet_Lapgan_EMNIST3/checkpoint.pth


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Resuming training from epoch 3.


KeyboardInterrupt: 