In [4]:
import helpMe
import wandb
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
from PIL import Image, ImageFilter

from torch.utils.data import DataLoader
import torchvision.datasets as Datasets
import torchvision.transforms as T
import torch.nn.functional as F


# Pick the compute device through the project helper (helpMe.get_default_device).
# NOTE(review): train_gan() below selects its own device, so this global is
# presumably only used by other cells — confirm before relying on it.
device = helpMe.get_default_device()
# device = 'cpu'

## Configurations

In [5]:
# Run name; also used to build the checkpoint directory.
model_name = "UNet_Lapgan"
# Full training resolution (the top level of the 8 -> 16 -> 32 pyramid).
image_size = 32
batch_size = 32
# CIFAR-10 images are RGB and every generator/discriminator in this notebook
# is built with in_channels=3 (the previous value of 1 was inconsistent).
channels = 3
epochs = 110

In [6]:
# Resize to the configured resolution, convert to a tensor, and shift pixel
# values from [0, 1] to [-1, 1] (mean 0.5, std 0.5, broadcast over channels).
transforms = T.Compose([
    T.Resize(image_size),
    T.ToTensor(),
    T.Normalize((0.5,), (0.5,))
])
# NOTE: 'Datasxts' looks like a typo, but it is where the data is already cached;
# renaming the folder would trigger a fresh download.
dataset = Datasets.CIFAR10(root='./Datasxts/CIFAR10/', train=True, download=True, transform=transforms)

Files already downloaded and verified


# Create the generators and discriminators


               if i % 100 == 0:
                    recon_imgs = smoothed_images + generated_high_freqs
                    helpMe.save_generated_images(torch.cat([real_high_freqs,generated_high_freqs],dim=0),torch.cat([images.to(device),recon_imgs],dim=0), epoch,i, checkpoint_dir, device)

In [8]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
import wandb  # Import wandb

# Ensure you have all necessary imports for your UNet and Discriminator models
import Unet  # Assuming you have a file named Unet.py with relevant model definitions

# # Initialize gradient scalers for mixed precision training
# scaler_G_32 = GradScaler()
# scaler_G_16 = GradScaler()
# scaler_G_8 = GradScaler()
# scaler_D_32 = GradScaler()
# scaler_D_16 = GradScaler()
# scaler_D_8 = GradScaler()

class LAPGAN(nn.Module):
    """Sampling wrapper that chains the three LAPGAN generators.

    ``forward`` draws an 8x8 sample from ``G_8(z, y)``. It then upsamples
    bilinearly by 2 twice, each time adding the residual predicted by
    ``G_16`` / ``G_32``. Intermediate images are saved with
    ``helpMe.save_generated_images`` and logged to wandb at each stage.

    Args:
        generators: ``(G_32, G_16, G_8)`` tuple; must contain exactly three modules.
        discriminators: accepted for signature symmetry; not used by this class.
        save_path: directory handed to ``helpMe.save_generated_images``.
        device: stored as ``self.device``; not otherwise used here.
    """

    def __init__(self, generators, discriminators, save_path, device):
        super().__init__()
        # Unpack first so a wrong-length tuple fails before any submodule is registered;
        # registration order (G_32, G_16, G_8) is kept for stable state_dict ordering.
        top, middle, base = generators
        self.G_32 = top
        self.G_16 = middle
        self.G_8 = base
        self.save_path = save_path
        self.device = device

    def forward(self, z, y, epoch, i):
        """Generate 32x32 samples from noise ``z`` and labels ``y``, recording each stage."""
        # Coarsest level: conditional base generator.
        sample = self.G_8(z, y)
        self._record(sample, epoch, i, '8_before')

        # 8x8 -> 16x16, then add G_16's predicted residual.
        sample = F.interpolate(sample, scale_factor=2, mode='bilinear')
        self._record(sample, epoch, i, '8_after')
        sample = sample + self.G_16(sample)
        self._record(sample, epoch, i, '16')

        # 16x16 -> 32x32, then add G_32's predicted residual (the bare upsample is not recorded).
        sample = F.interpolate(sample, scale_factor=2, mode='bilinear')
        sample = sample + self.G_32(sample)
        self._record(sample, epoch, i, '32')

        return sample

    def _record(self, images, epoch, i, res):
        """Save ``images`` through helpMe and log them to wandb under the tag ``res``."""
        helpMe.save_generated_images(genH_realH=images, recon=None, epoch=epoch, i=i,
                                     path=self.save_path, res=res, a='LAPgan/')
        self.log_generated_images(images, epoch, i, res=res)

    def log_generated_images(self, images, epoch, i, res):
        """Log a normalised 8-column grid of ``images`` to wandb as "Generated Images {res}"."""
        grid = torchvision.utils.make_grid(images.detach().cpu(), normalize=True, scale_each=True, nrow=8)
        caption = f"Epoch: {epoch}"
        wandb.log({f"Generated Images {res}": [wandb.Image(grid, caption=caption)]})



def load_model(generators, discriminators, opt_G_32, opt_D_32, opt_G_16, opt_D_16, opt_G_8, opt_D_8, checkpoint_dir, device):
    """Restore all networks and optimizers from ``<checkpoint_dir>/checkpoint.pth``.

    ``generators`` and ``discriminators`` are indexed (0, 1, 2) = (32, 16, 8).
    Tensors are mapped onto ``device`` while loading.

    Returns:
        The epoch to resume from: the saved epoch + 1, or 1 when no checkpoint exists.
    """
    ckpt_file = os.path.join(checkpoint_dir, 'checkpoint.pth')
    print(ckpt_file)

    if not os.path.exists(ckpt_file):
        print("Starting training from scratch.")
        return 1

    state = torch.load(ckpt_file, map_location=device)
    # (object, checkpoint key) pairs, restored in the same order the checkpoint lists them.
    restore_plan = (
        (generators[0], 'G_32_state_dict'),
        (discriminators[0], 'D_32_state_dict'),
        (generators[1], 'G_16_state_dict'),
        (discriminators[1], 'D_16_state_dict'),
        (generators[2], 'G_8_state_dict'),
        (discriminators[2], 'D_8_state_dict'),
        (opt_G_32, 'opt_G_32_state_dict'),
        (opt_D_32, 'opt_D_32_state_dict'),
        (opt_G_16, 'opt_G_16_state_dict'),
        (opt_D_16, 'opt_D_16_state_dict'),
        (opt_G_8, 'opt_G_8_state_dict'),
        (opt_D_8, 'opt_D_8_state_dict'),
    )
    for target, key in restore_plan:
        target.load_state_dict(state[key])

    resume_epoch = state['epoch'] + 1
    print(f"Resuming training from epoch {resume_epoch}.")
    return resume_epoch

def save_model(generators, discriminators, opt_G_32, opt_D_32, opt_G_16, opt_D_16, opt_G_8, opt_D_8, epoch, checkpoint_dir):
    """Write every network and optimizer state to ``<checkpoint_dir>/checkpoint.pth``.

    ``generators`` and ``discriminators`` are indexed (0, 1, 2) = (32, 16, 8).
    The same file is overwritten on every call. It is first written to a
    temporary sibling file and then moved into place with ``os.replace``. An
    interruption during saving (e.g. KeyboardInterrupt) therefore leaves the
    previous checkpoint intact, rather than a truncated file that
    ``load_model`` cannot read.

    Args:
        generators, discriminators: 3-element sequences of modules.
        opt_*: the matching optimizers.
        epoch: last completed epoch; stored under ``'epoch'``.
        checkpoint_dir: target directory; created if missing.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')
    tmp_path = checkpoint_path + '.tmp'
    state = {
        'epoch': epoch,
        'G_32_state_dict': generators[0].state_dict(),
        'D_32_state_dict': discriminators[0].state_dict(),
        'G_16_state_dict': generators[1].state_dict(),
        'D_16_state_dict': discriminators[1].state_dict(),
        'G_8_state_dict': generators[2].state_dict(),
        'D_8_state_dict': discriminators[2].state_dict(),
        'opt_G_32_state_dict': opt_G_32.state_dict(),
        'opt_D_32_state_dict': opt_D_32.state_dict(),
        'opt_G_16_state_dict': opt_G_16.state_dict(),
        'opt_D_16_state_dict': opt_D_16.state_dict(),
        'opt_G_8_state_dict': opt_G_8.state_dict(),
        'opt_D_8_state_dict': opt_D_8.state_dict()
    }
    try:
        torch.save(state, tmp_path)
        # Atomic swap: readers see either the old or the new complete checkpoint.
        os.replace(tmp_path, checkpoint_path)
    finally:
        # Only present if saving or replacing failed part-way; never leave debris behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import wandb

def save_model(generators, discriminators, opt_G_32, opt_D_32, opt_G_16, opt_D_16, opt_G_8, opt_D_8, epoch, checkpoint_dir):
    """Write every network and optimizer state to ``<checkpoint_dir>/checkpoint.pth``.

    ``generators`` and ``discriminators`` are indexed (0, 1, 2) = (32, 16, 8).
    The same file is overwritten on every call. It is first written to a
    temporary sibling file and then moved into place with ``os.replace``. An
    interruption during saving (e.g. KeyboardInterrupt) therefore leaves the
    previous checkpoint intact, rather than a truncated file that
    ``load_model`` cannot read.

    Args:
        generators, discriminators: 3-element sequences of modules.
        opt_*: the matching optimizers.
        epoch: last completed epoch; stored under ``'epoch'``.
        checkpoint_dir: target directory; created if missing.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')
    tmp_path = checkpoint_path + '.tmp'
    state = {
        'epoch': epoch,
        'G_32_state_dict': generators[0].state_dict(),
        'D_32_state_dict': discriminators[0].state_dict(),
        'G_16_state_dict': generators[1].state_dict(),
        'D_16_state_dict': discriminators[1].state_dict(),
        'G_8_state_dict': generators[2].state_dict(),
        'D_8_state_dict': discriminators[2].state_dict(),
        'opt_G_32_state_dict': opt_G_32.state_dict(),
        'opt_D_32_state_dict': opt_D_32.state_dict(),
        'opt_G_16_state_dict': opt_G_16.state_dict(),
        'opt_D_16_state_dict': opt_D_16.state_dict(),
        'opt_G_8_state_dict': opt_G_8.state_dict(),
        'opt_D_8_state_dict': opt_D_8.state_dict()
    }
    try:
        torch.save(state, tmp_path)
        # Atomic swap: readers see either the old or the new complete checkpoint.
        os.replace(tmp_path, checkpoint_path)
    finally:
        # Only present if saving or replacing failed part-way; never leave debris behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def load_model(generators, discriminators, opt_G_32, opt_D_32, opt_G_16, opt_D_16, opt_G_8, opt_D_8, checkpoint_dir, device):
    """Restore all networks and optimizers from ``<checkpoint_dir>/checkpoint.pth``.

    ``generators`` and ``discriminators`` are indexed (0, 1, 2) = (32, 16, 8).
    Tensors are mapped onto ``device`` while loading.

    Returns:
        The epoch to resume from: the saved epoch + 1, or 1 when no checkpoint exists.
    """
    ckpt_file = os.path.join(checkpoint_dir, 'checkpoint.pth')
    print(ckpt_file)

    if not os.path.exists(ckpt_file):
        print("Starting training from scratch.")
        return 1

    state = torch.load(ckpt_file, map_location=device)
    # (object, checkpoint key) pairs, restored in the same order the checkpoint lists them.
    restore_plan = (
        (generators[0], 'G_32_state_dict'),
        (discriminators[0], 'D_32_state_dict'),
        (generators[1], 'G_16_state_dict'),
        (discriminators[1], 'D_16_state_dict'),
        (generators[2], 'G_8_state_dict'),
        (discriminators[2], 'D_8_state_dict'),
        (opt_G_32, 'opt_G_32_state_dict'),
        (opt_D_32, 'opt_D_32_state_dict'),
        (opt_G_16, 'opt_G_16_state_dict'),
        (opt_D_16, 'opt_D_16_state_dict'),
        (opt_G_8, 'opt_G_8_state_dict'),
        (opt_D_8, 'opt_D_8_state_dict'),
    )
    for target, key in restore_plan:
        target.load_state_dict(state[key])

    resume_epoch = state['epoch'] + 1
    print(f"Resuming training from epoch {resume_epoch}.")
    return resume_epoch


def train_gan(generators, discriminators, dataloader, num_epochs, batch_size, noise_dim=100, checkpoint_dir=None, d8_steps=5):
    """Train the three-level (8x8 -> 16x16 -> 32x32) conditional LAPGAN.

    For each data batch:
      * Level 8: ``d8_steps`` times, shuffle the 8x8 real images together with
        their labels, then do one D_8 and one G_8 update (BCE loss). G_8 and
        D_8 are both conditioned on the labels.
      * Level 16: upsample the last G_8 output by 2. G_16 predicts a
        high-frequency component that is added to it. D_16 compares this with
        the real high-frequency component returned by ``helpMe.to_gaus``, and
        G_16 is trained with BCE + L1 against that real component.
      * Level 32: the same as level 16 one scale up, with G_32 / D_32.

    The real images used at levels 16 and 32 are reordered with the cumulative
    permutation applied at level 8. Each generated sample is therefore
    compared (L1) with a real image of the label it was conditioned on.

    After each epoch this function:
      * prints the average losses;
      * renders one sample pyramid from the same fixed noise/labels via ``LAPGAN``;
      * saves a checkpoint, when ``checkpoint_dir`` is set.

    Args:
        generators: ``(G_32, G_16, G_8)``.
        discriminators: ``(D_32, D_16, D_8)``.
        dataloader: iterable of ``(images, labels)`` batches with ``len() > 0``.
        num_epochs: last epoch to train (inclusive, 1-based).
        batch_size: recorded in the wandb config. The noise batch follows the
            actual size of each data batch.
        noise_dim: latent size passed to G_8.
        checkpoint_dir: directory for checkpoints and sample images. If falsy,
            no checkpoint is loaded or saved.
        d8_steps: number of D_8/G_8 updates per data batch (must be >= 1).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    G_32, G_16, G_8 = [gen.to(device) for gen in generators]
    D_32, D_16, D_8 = [disc.to(device) for disc in discriminators]

    # Optimizers: identical Adam settings for every network
    lr, betas = 2e-4, (0.5, 0.999)
    opt_G_32 = optim.Adam(G_32.parameters(), lr=lr, betas=betas)
    opt_D_32 = optim.Adam(D_32.parameters(), lr=lr, betas=betas)
    opt_G_16 = optim.Adam(G_16.parameters(), lr=lr, betas=betas)
    opt_D_16 = optim.Adam(D_16.parameters(), lr=lr, betas=betas)
    opt_G_8 = optim.Adam(G_8.parameters(), lr=lr, betas=betas)
    opt_D_8 = optim.Adam(D_8.parameters(), lr=lr, betas=betas)

    # Load checkpoint if it exists
    if checkpoint_dir:
        start_epoch = load_model(generators, discriminators, opt_G_32, opt_D_32, opt_G_16, opt_D_16, opt_G_8, opt_D_8, checkpoint_dir, device)
    else:
        start_epoch = 1

    # Loss functions
    criterion = nn.BCELoss()
    l1_loss = nn.L1Loss()

    # Initialize wandb
    wandb.init(project="LAPGAN", config={
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": lr,
        "betas": betas
    })

    # Visualisation model and inputs are created ONCE, so every epoch is rendered
    # from the same noise. Previously the "fixed" noise was redrawn each epoch,
    # which made samples from different epochs incomparable.
    lapgan_model = LAPGAN(generators, discriminators, checkpoint_dir, device)
    lapgan_model.to(device)
    fixed_noise = torch.randn(10, noise_dim, device=device)
    fixed_labels = torch.arange(0, 10, dtype=torch.long, device=device)

    for epoch in range(start_epoch, num_epochs + 1):
        total_d32_loss = 0.0
        total_g32_loss = 0.0
        total_d16_loss = 0.0
        total_g16_loss = 0.0
        total_d8_loss = 0.0
        total_g8_loss = 0.0

        with tqdm(enumerate(dataloader), total=len(dataloader)) as t:
            for i, (images, labels) in t:
                images_8 = F.interpolate(images, scale_factor=0.25, mode='bilinear')
                # Cumulative permutation applied to images_8/labels below. The 16/32 levels
                # use it to pick real images in the same order as the labels G_8 saw.
                # Indexing with only the last `perm` misaligned them.
                order = torch.arange(images_8.size(0))

                for _ in range(d8_steps):
                    perm = torch.randperm(images_8.size(0))
                    images_8 = images_8[perm]
                    labels = labels[perm]
                    order = order[perm]

                    real = images_8.to(device)
                    labels = labels.to(device)
                    # Match the real batch size instead of trusting the batch_size argument
                    noise = torch.randn(real.size(0), noise_dim, device=device)
                    fake = G_8(noise, labels)

                    # Train Discriminator(D_8)
                    D_8_real = D_8(real, labels).view(-1)
                    loss_D_8_real = criterion(D_8_real, torch.ones_like(D_8_real))
                    D_8_fake = D_8(fake.detach(), labels).view(-1)
                    loss_D_8_fake = criterion(D_8_fake, torch.zeros_like(D_8_fake))
                    loss_D_8 = (loss_D_8_real + loss_D_8_fake) / 2
                    opt_D_8.zero_grad()
                    loss_D_8.backward()
                    opt_D_8.step()

                    total_d8_loss += loss_D_8.item()

                    # Train Generator for 8x8 images
                    output = D_8(fake, labels).view(-1)
                    loss_G_8 = criterion(output, torch.ones_like(output))

                    opt_G_8.zero_grad()
                    loss_G_8.backward()
                    opt_G_8.step()

                    total_g8_loss += loss_G_8.item()

                if i % 100 == 0:
                    helpMe.save_generated_images(genH_realH=torch.cat([real, fake], dim=0), recon=None, epoch=epoch, i=i,
                                                 path=checkpoint_dir, res='8', a='')

                # Upsampled 8x8 samples are the (constant) input of the 16x16 level
                images_G8 = F.interpolate(fake, scale_factor=2, mode='bilinear').detach()

                del images_8, labels, noise, fake, D_8_real, D_8_fake, loss_D_8_real, loss_D_8_fake, output

                # Real 16x16 images, aligned with the labels G_8 was conditioned on
                images_16 = F.interpolate(images, scale_factor=0.5, mode='bilinear')[order]
                _, real_high_freqs = helpMe.to_gaus(imgs=images_16)
                real_high_freqs = real_high_freqs.to(device)

                # Train Discriminator(D_16)
                opt_D_16.zero_grad()
                output_real_16 = D_16(real_high_freqs).view(-1)
                loss_disc_real_16 = criterion(output_real_16, torch.ones_like(output_real_16))
                generated_high_freqs_16 = G_16(images_G8)
                output_fake_16 = D_16(generated_high_freqs_16.detach()).view(-1)
                loss_disc_fake_16 = criterion(output_fake_16, torch.zeros_like(output_fake_16))
                loss_disc_16 = (loss_disc_real_16 + loss_disc_fake_16) / 2
                loss_disc_16.backward()
                opt_D_16.step()

                total_d16_loss += loss_disc_16.item()

                # Train Generator for 16x16 images
                opt_G_16.zero_grad()
                output_fake_16 = D_16(generated_high_freqs_16).view(-1)
                adv_loss_16 = criterion(output_fake_16, torch.ones_like(output_fake_16))
                l1_loss_16 = l1_loss(generated_high_freqs_16, real_high_freqs)
                loss_gen_16 = adv_loss_16 + l1_loss_16
                loss_gen_16.backward()
                opt_G_16.step()

                total_g16_loss += loss_gen_16.item()

                recon_imgs = images_G8 + generated_high_freqs_16
                if i % 100 == 0:
                    helpMe.save_generated_images(genH_realH=torch.cat([real_high_freqs, generated_high_freqs_16], dim=0),
                                                 recon=torch.cat([images_G8, recon_imgs], dim=0), epoch=epoch, i=i,
                                                 path=checkpoint_dir, res='16', a='')

                images_G16 = F.interpolate(recon_imgs, scale_factor=2, mode='bilinear').detach()
                del images_16, real_high_freqs, generated_high_freqs_16, output_real_16, output_fake_16, loss_disc_real_16, loss_disc_fake_16, adv_loss_16, l1_loss_16, recon_imgs

                # Real 32x32 high frequencies, aligned the same way
                _, real_high_freqs = helpMe.to_gaus(imgs=images)
                real_high_freqs = real_high_freqs.to(device)[order.to(device)]

                # Train Discriminator (D_32)
                opt_D_32.zero_grad()
                output_real_32 = D_32(real_high_freqs).view(-1)
                loss_disc_real_32 = criterion(output_real_32, torch.ones_like(output_real_32))
                generated_high_freqs_32 = G_32(images_G16)
                output_fake_32 = D_32(generated_high_freqs_32.detach()).view(-1)
                loss_disc_fake_32 = criterion(output_fake_32, torch.zeros_like(output_fake_32))
                loss_disc_32 = (loss_disc_real_32 + loss_disc_fake_32) / 2
                # No retain_graph: the generator step below re-runs D_32, so this graph is not reused
                loss_disc_32.backward()
                opt_D_32.step()

                total_d32_loss += loss_disc_32.item()

                # Train Generator for 32x32 images
                opt_G_32.zero_grad()
                output_fake_32 = D_32(generated_high_freqs_32).view(-1)
                adv_loss_32 = criterion(output_fake_32, torch.ones_like(output_fake_32))
                l1_loss_32 = l1_loss(generated_high_freqs_32, real_high_freqs)
                loss_gen_32 = adv_loss_32 + l1_loss_32
                loss_gen_32.backward()
                opt_G_32.step()

                total_g32_loss += loss_gen_32.item()

                recon_imgs = images_G16 + generated_high_freqs_32
                if i % 100 == 0:
                    helpMe.save_generated_images(genH_realH=torch.cat([real_high_freqs, generated_high_freqs_32], dim=0),
                                                 recon=torch.cat([images_G16, recon_imgs], dim=0), epoch=epoch, i=i,
                                                 path=checkpoint_dir, res='32', a='')

                del images, real_high_freqs, generated_high_freqs_32, output_real_32, output_fake_32, loss_disc_real_32, loss_disc_fake_32, adv_loss_32, l1_loss_32, recon_imgs

                t.set_description(f'Epoch [{epoch}/{num_epochs}]')
                t.set_postfix({
                    'D32_loss': f'{loss_disc_32:.3f}',
                    'G32_loss': f'{loss_gen_32:.3f}',
                    'D16_loss': f'{loss_disc_16:.3f}',
                    'G16_loss': f'{loss_gen_16:.3f}',
                    'D8_loss': f'{loss_D_8:.3f}',
                    'G8_loss': f'{loss_G_8:.3f}'
                })

                wandb.log({
                    'D32_loss': loss_disc_32.item(),
                    'G32_loss': loss_gen_32.item(),
                    'D16_loss': loss_disc_16.item(),
                    'G16_loss': loss_gen_16.item(),
                    'D8_loss': loss_D_8.item(),
                    'G8_loss': loss_G_8.item()
                })

                # NOTE(review): emptying the CUDA cache every step is costly; presumably kept
                # for a shared GPU — drop it if that is not the case.
                torch.cuda.empty_cache()

        n_batches = len(dataloader)
        avg_d32_loss = total_d32_loss / n_batches
        avg_g32_loss = total_g32_loss / n_batches
        avg_d16_loss = total_d16_loss / n_batches
        avg_g16_loss = total_g16_loss / n_batches
        # Level 8 runs d8_steps updates per batch
        avg_d8_loss = total_d8_loss / (n_batches * d8_steps)
        avg_g8_loss = total_g8_loss / (n_batches * d8_steps)
        print(f"Epoch {epoch}: Loss D32: {avg_d32_loss:.4f}, Loss G32: {avg_g32_loss:.4f}, Loss D16: {avg_d16_loss:.4f}, Loss G16: {avg_g16_loss:.4f}, Loss D8: {avg_d8_loss:.4f}, Loss G8: {avg_g8_loss:.4f}")

        # Render the full 8 -> 16 -> 32 pyramid from the fixed inputs. This is
        # visualisation only, so do not build an autograd graph.
        with torch.no_grad():
            lapgan_model(fixed_noise, fixed_labels, epoch, i)

        # Save the model (save_model cannot handle a missing directory argument)
        if checkpoint_dir:
            save_model(generators, discriminators, opt_G_32, opt_D_32, opt_G_16, opt_D_16, opt_G_8, opt_D_8, epoch, checkpoint_dir)

# Build the three generator/discriminator pairs, ordered (32x32, 16x16, 8x8)
# as train_gan expects. `dataset` and the config values come from earlier cells.
gen32 = Unet.UNetGenerator(in_channels=3, out_channels=3)
disc32 = Unet.Discriminator(in_channels=3)
gen16 = Unet.UNetGenerator(features=[64, 128, 256], in_channels=3, out_channels=3)
disc16 = Unet.Discriminator(features=[64, 128, 256], in_channels=3)
gen8 = Unet.Generator_L2()
disc8 = Unet.Discriminator_L2()

# drop_last=True keeps every training batch at exactly `batch_size` samples.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

generators = (gen32, gen16, gen8)
discriminators = (disc32, disc16, disc8)

checkpoint_dir = f"Models/{model_name}3/"

try:
    # Use the configured batch size (was a hard-coded 32 that could drift from the config).
    train_gan(generators, discriminators, dataloader, num_epochs=epochs, batch_size=batch_size, checkpoint_dir=checkpoint_dir)
finally:
    # Close the wandb run even when training is interrupted (e.g. KeyboardInterrupt).
    wandb.finish()


Models/UNet_Lapgan3/checkpoint.pth
Starting training from scratch.


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33maruntd008[0m ([33maruntd08[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

Epoch [1/110]: 100%|██████████| 1562/1562 [04:19<00:00,  6.03it/s, D32_loss=0.731, G32_loss=1.404, D16_loss=0.696, G16_loss=1.243, D8_loss=0.446, G8_loss=1.510]


Epoch 1: Loss D32: 0.6370, Loss G32: 2.0579, Loss D16: 0.6833, Loss G16: 1.4286, Loss D8: 0.3986, Loss G8: 2.1082


Epoch [2/110]: 100%|██████████| 1562/1562 [06:33<00:00,  3.97it/s, D32_loss=0.665, G32_loss=1.703, D16_loss=0.687, G16_loss=1.232, D8_loss=0.606, G8_loss=1.001]


Epoch 2: Loss D32: 0.6654, Loss G32: 1.5306, Loss D16: 0.6940, Loss G16: 1.2358, Loss D8: 0.5618, Loss G8: 1.2355


Epoch [3/110]:   3%|▎         | 48/1562 [00:14<07:09,  3.53it/s, D32_loss=0.663, G32_loss=1.812, D16_loss=0.704, G16_loss=1.207, D8_loss=0.583, G8_loss=1.091][34m[1mwandb[0m: Network error resolved after 0:00:20.540463, resuming normal operation.
Epoch [3/110]: 100%|██████████| 1562/1562 [07:31<00:00,  3.46it/s, D32_loss=0.587, G32_loss=2.019, D16_loss=0.670, G16_loss=1.339, D8_loss=0.607, G8_loss=0.954]


Epoch 3: Loss D32: 0.6106, Loss G32: 1.8752, Loss D16: 0.6913, Loss G16: 1.2542, Loss D8: 0.6197, Loss G8: 0.9963


Epoch [4/110]: 100%|██████████| 1562/1562 [06:17<00:00,  4.13it/s, D32_loss=0.404, G32_loss=2.057, D16_loss=0.728, G16_loss=1.237, D8_loss=0.600, G8_loss=0.910]


Epoch 4: Loss D32: 0.5930, Loss G32: 1.9882, Loss D16: 0.6913, Loss G16: 1.2579, Loss D8: 0.6353, Loss G8: 0.9441


Epoch [5/110]: 100%|██████████| 1562/1562 [03:53<00:00,  6.69it/s, D32_loss=0.441, G32_loss=2.305, D16_loss=0.728, G16_loss=1.230, D8_loss=0.643, G8_loss=0.827]


Epoch 5: Loss D32: 0.5821, Loss G32: 2.0135, Loss D16: 0.6888, Loss G16: 1.2672, Loss D8: 0.6484, Loss G8: 0.8973


Epoch [6/110]: 100%|██████████| 1562/1562 [03:51<00:00,  6.74it/s, D32_loss=0.657, G32_loss=2.018, D16_loss=0.696, G16_loss=1.293, D8_loss=0.614, G8_loss=0.852]


Epoch 6: Loss D32: 0.6034, Loss G32: 1.9114, Loss D16: 0.6848, Loss G16: 1.2868, Loss D8: 0.6586, Loss G8: 0.8573


Epoch [7/110]: 100%|██████████| 1562/1562 [06:17<00:00,  4.14it/s, D32_loss=0.692, G32_loss=1.740, D16_loss=0.646, G16_loss=1.312, D8_loss=0.638, G8_loss=0.805]


Epoch 7: Loss D32: 0.5520, Loss G32: 2.1284, Loss D16: 0.6800, Loss G16: 1.3043, Loss D8: 0.6656, Loss G8: 0.8320


Epoch [8/110]:   2%|▏         | 39/1562 [00:12<07:54,  3.21it/s, D32_loss=0.556, G32_loss=2.236, D16_loss=0.704, G16_loss=1.316, D8_loss=0.681, G8_loss=0.769][34m[1mwandb[0m: Network error resolved after 0:00:21.483126, resuming normal operation.
Epoch [8/110]: 100%|██████████| 1562/1562 [06:28<00:00,  4.02it/s, D32_loss=0.524, G32_loss=2.183, D16_loss=0.661, G16_loss=1.327, D8_loss=0.646, G8_loss=0.768]


Epoch 8: Loss D32: 0.5670, Loss G32: 2.1514, Loss D16: 0.6836, Loss G16: 1.2957, Loss D8: 0.6690, Loss G8: 0.8200


Epoch [9/110]:   5%|▍         | 74/1562 [00:13<07:03,  3.51it/s, D32_loss=0.523, G32_loss=2.187, D16_loss=0.673, G16_loss=1.295, D8_loss=0.686, G8_loss=0.756][34m[1mwandb[0m: Network error resolved after 0:00:21.286607, resuming normal operation.
Epoch [9/110]: 100%|██████████| 1562/1562 [07:22<00:00,  3.53it/s, D32_loss=0.303, G32_loss=3.015, D16_loss=0.682, G16_loss=1.298, D8_loss=0.638, G8_loss=0.831]


Epoch 9: Loss D32: 0.5348, Loss G32: 2.2392, Loss D16: 0.6793, Loss G16: 1.3120, Loss D8: 0.6712, Loss G8: 0.8087


Epoch [10/110]:   1%|          | 9/1562 [00:02<07:17,  3.55it/s, D32_loss=0.382, G32_loss=2.405, D16_loss=0.666, G16_loss=1.336, D8_loss=0.707, G8_loss=0.795][34m[1mwandb[0m: Network error resolved after 0:00:21.467540, resuming normal operation.
Epoch [10/110]: 100%|██████████| 1562/1562 [07:28<00:00,  3.48it/s, D32_loss=0.703, G32_loss=2.039, D16_loss=0.706, G16_loss=1.312, D8_loss=0.681, G8_loss=0.767]


Epoch 10: Loss D32: 0.4715, Loss G32: 2.5885, Loss D16: 0.6725, Loss G16: 1.3318, Loss D8: 0.6729, Loss G8: 0.8052


Epoch [11/110]:   0%|          | 4/1562 [00:01<09:17,  2.79it/s, D32_loss=0.157, G32_loss=2.834, D16_loss=0.685, G16_loss=1.287, D8_loss=0.668, G8_loss=0.788][34m[1mwandb[0m: Network error resolved after 0:00:20.916872, resuming normal operation.
Epoch [11/110]: 100%|██████████| 1562/1562 [07:32<00:00,  3.45it/s, D32_loss=0.647, G32_loss=2.406, D16_loss=0.685, G16_loss=1.289, D8_loss=0.642, G8_loss=0.790]


Epoch 11: Loss D32: 0.4765, Loss G32: 2.6049, Loss D16: 0.6773, Loss G16: 1.3227, Loss D8: 0.6763, Loss G8: 0.7888


Epoch [12/110]:   2%|▏         | 28/1562 [00:08<07:22,  3.47it/s, D32_loss=0.591, G32_loss=2.419, D16_loss=0.676, G16_loss=1.257, D8_loss=0.629, G8_loss=0.839][34m[1mwandb[0m: Network error resolved after 0:00:21.635574, resuming normal operation.
Epoch [12/110]: 100%|██████████| 1562/1562 [05:19<00:00,  4.89it/s, D32_loss=0.702, G32_loss=2.230, D16_loss=0.710, G16_loss=1.285, D8_loss=0.671, G8_loss=0.777]


Epoch 12: Loss D32: 0.4795, Loss G32: 2.6018, Loss D16: 0.6756, Loss G16: 1.3296, Loss D8: 0.6782, Loss G8: 0.7818


Epoch [13/110]: 100%|██████████| 1562/1562 [04:58<00:00,  5.23it/s, D32_loss=0.239, G32_loss=3.167, D16_loss=0.709, G16_loss=1.304, D8_loss=0.687, G8_loss=0.749]


Epoch 13: Loss D32: 0.4173, Loss G32: 3.0238, Loss D16: 0.6692, Loss G16: 1.3528, Loss D8: 0.6813, Loss G8: 0.7678


Epoch [14/110]: 100%|██████████| 1562/1562 [06:05<00:00,  4.27it/s, D32_loss=0.268, G32_loss=3.712, D16_loss=0.638, G16_loss=1.450, D8_loss=0.683, G8_loss=0.741]


Epoch 14: Loss D32: 0.4111, Loss G32: 3.1145, Loss D16: 0.6699, Loss G16: 1.3532, Loss D8: 0.6815, Loss G8: 0.7685


KeyboardInterrupt: 

[34m[1mwandb[0m: Network error resolved after 0:00:21.775047, resuming normal operation.
