In [36]:
import helpMe
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
from PIL import Image, ImageFilter

from torch.utils.data import DataLoader
import torchvision.datasets as Datasets
import torchvision.transforms as T


device = helpMe.get_default_device()
device

device(type='cuda')

## Configrations

In [37]:
model_name = "UNet_8"
image_size = 8
batch_size = 32
# z_dim = 128
# DATA_DIR = './imageNet_lp/torch_image_folder/mnt/volume_sfo3_01/imagenet-lt/ImageDataset/train'
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
channels =1
epochs = 110

In [38]:
class UNetGenerator(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, features=[64, 128]):
        super(UNetGenerator, self).__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.bottleneck = nn.Conv2d(features[-1], features[-1]*2, kernel_size=3, padding=1)

        # Downsampling part
        for feature in features:
            self.downs.append(self.conv_block(in_channels, feature))
            in_channels = feature

        # Upsampling part
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.ups.append(self.conv_block(feature*2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = nn.MaxPool2d(kernel_size=2, stride=2)(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]
            if x.shape != skip_connection.shape:
                x = nn.functional.interpolate(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)


In [39]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=1, features=[64, 128]):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1)
        self.conv_layers = nn.ModuleList()
        
        in_features = features[0]
        for feature in features[1:]:
            self.conv_layers.append(self._block(in_features, feature, stride=2))
            in_features = feature
        
        self.final_conv = nn.Conv2d(in_features, 1, kernel_size=2, stride=1, padding=0)
        
    def _block(self, in_channels, out_channels, stride):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )
    
    def forward(self, x):
        x = nn.LeakyReLU(0.2, inplace=True)(self.conv1(x))
        for layer in self.conv_layers:
            x = layer(x)
            # print(x.shape)
            
        return torch.sigmoid(self.final_conv(x))


In [40]:
G,D = UNetGenerator(), Discriminator()

In [41]:
print('Number of params in G: {} D: {}'.format(
*[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]]))

Number of params in G: 1273153 D: 133057


In [42]:
transforms = T.Compose([
    T.Resize(image_size),
    T.ToTensor(),
    T.Normalize((0.5,), (0.5,))
])
dataset = Datasets.MNIST(root='./Datasxts/MNIST/', train=True, download=True,transform=transforms)

In [43]:
from PIL import Image, ImageFilter

def to_gaus(imgs):
    smoothed_imgs = []
    higher_freq = []

    for img_tensor in imgs:

        img = T.ToPILImage()(img_tensor)
         
        S_img = img.filter(ImageFilter.GaussianBlur(radius=1))  # Adjust the radius as needed
              
        # H_img = T.ToPILImage()(H_img)
        S_img = T.ToTensor()(S_img)
        H_img = img_tensor - S_img
        higher_freq.append(H_img)
        smoothed_imgs.append(S_img)

    smoothed_imgs = torch.stack(smoothed_imgs)
    higher_freq = torch.stack(higher_freq)
    torchvision.utils.save_image(smoothed_imgs.detach(), f"smoooooo_8.png", normalize=True,nrow=8)
    return smoothed_imgs,higher_freq

In [44]:
import os
def save_generated_images(genH_realH, recon, epoch,i, path, device):
    os.makedirs(f"{path}Generated", exist_ok=True)
    torchvision.utils.save_image(genH_realH.detach(), f"{path}Generated/{epoch}_{i}_generated_images_epoch.png", normalize=True,nrow=8)
    torchvision.utils.save_image(recon.detach(), f"{path}Generated/{epoch}_{i}_generated_recon_epoch.png", normalize=True,nrow=8)
    
    


In [45]:
dataloader = DataLoader(dataset, 32, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

In [46]:
# AMP Scalers
scaler_G = GradScaler()
scaler_D = GradScaler()

def train_gan(generator, discriminator, dataloader, num_epochs, batch_size, checkpoint_dir=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)
    
    # Optimizers
    opt_gen = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_disc = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    
    # Loss functions
    criterion = nn.BCELoss()
    l1_loss = nn.L1Loss()

 
    
    start_epoch = 1
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=device)
            generator.load_state_dict(checkpoint['generator_state_dict'])
            discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
            opt_gen.load_state_dict(checkpoint['optimizer_G_state_dict'])
            opt_disc.load_state_dict(checkpoint['optimizer_D_state_dict'])
            scaler_G.load_state_dict(checkpoint['scaler_G'])
            scaler_D.load_state_dict(checkpoint['scaler_D'])
            start_epoch = checkpoint['epoch'] + 1
            print(f"Resuming training from epoch {start_epoch}.")

    for epoch in range(start_epoch, num_epochs + 1):
        total_d_loss = 0.0
        total_g_loss = 0.0
        
        with tqdm(enumerate(dataloader), total=len(dataloader)) as t:
            for i, (images, _) in t:
                smoothed_images, real_high_freqs = to_gaus(imgs=images)
                smoothed_images = smoothed_images.to(device)
                real_high_freqs = real_high_freqs.to(device)

                # Train Discriminator
                opt_disc.zero_grad()
                output_real = discriminator(real_high_freqs).view(-1)
                loss_disc_real = criterion(output_real, torch.ones_like(output_real))
                generated_high_freqs = generator(smoothed_images)
                output_fake = discriminator(generated_high_freqs.detach()).view(-1)
                loss_disc_fake = criterion(output_fake, torch.zeros_like(output_fake))
                loss_disc = (loss_disc_real + loss_disc_fake) / 2
                loss_disc.backward()
                opt_disc.step()

                total_d_loss += loss_disc.item()

                # Train Generator
                opt_gen.zero_grad()
                output_fake = discriminator(generated_high_freqs).view(-1)
                loss_gen = criterion(output_fake, torch.ones_like(output_fake))
                loss_l1 = l1_loss(generated_high_freqs, real_high_freqs)
                loss_generator = loss_gen + 100 * loss_l1
                loss_generator.backward()
                opt_gen.step()

                total_g_loss += loss_generator.item()
                
                t.set_description(f'Epoch [{epoch}/{num_epochs}]')
                t.set_postfix({'D_loss': f'{loss_disc:.3f}',
                               'G_loss': f'{loss_generator:.3f}'})
                
                if i % 100 == 0:
                    recon_imgs = smoothed_images + generated_high_freqs
                    save_generated_images(torch.cat([real_high_freqs,generated_high_freqs],dim=0),torch.cat([images.to(device),recon_imgs],dim=0), epoch,i, checkpoint_dir, device)
                
                del images,smoothed_images, real_high_freqs, generated_high_freqs, output_real, output_fake, loss_disc_real, loss_disc_fake, loss_disc, loss_gen, loss_l1, loss_generator
                torch.cuda.empty_cache()
            
        avg_d_loss = total_d_loss / len(dataloader)
        avg_g_loss = total_g_loss / len(dataloader)
        print(f"Epoch [{epoch}/{num_epochs}] Loss D: {avg_d_loss:.4f}, loss G: {avg_g_loss:.4f}")

        # Save generated images

        # Save the model
        save_model(generator, discriminator, opt_gen, opt_disc, epoch, checkpoint_dir)
        
def save_model(generator, discriminator, opt_gen, opt_disc, epoch, checkpoint_dir):
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')
    torch.save({
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': opt_gen.state_dict(),
        'optimizer_D_state_dict': opt_disc.state_dict(),
        'scaler_G': scaler_G.state_dict(),
        'scaler_D': scaler_D.state_dict()
    }, checkpoint_path)



generator = UNetGenerator().to(device)
discriminator = Discriminator().to(device)

checkpoint_dir = f"Models/{model_name}/"

train_gan(generator, discriminator, dataloader, num_epochs =epochs, batch_size=batch_size, checkpoint_dir=checkpoint_dir)


  0%|          | 0/1875 [00:00<?, ?it/s]

Epoch [1/110]: 100%|██████████| 1875/1875 [00:58<00:00, 32.24it/s, D_loss=0.583, G_loss=7.934] 


Epoch [1/110] Loss D: 0.5633, loss G: 11.6695


Epoch [2/110]: 100%|██████████| 1875/1875 [00:56<00:00, 32.97it/s, D_loss=0.664, G_loss=5.692]


Epoch [2/110] Loss D: 0.6322, loss G: 6.5032


Epoch [3/110]: 100%|██████████| 1875/1875 [00:55<00:00, 33.56it/s, D_loss=0.757, G_loss=4.312]


Epoch [3/110] Loss D: 0.6558, loss G: 5.1506


Epoch [4/110]: 100%|██████████| 1875/1875 [00:57<00:00, 32.49it/s, D_loss=0.587, G_loss=3.834]


Epoch [4/110] Loss D: 0.6623, loss G: 4.3553


Epoch [5/110]: 100%|██████████| 1875/1875 [00:56<00:00, 33.09it/s, D_loss=0.707, G_loss=4.170]


Epoch [5/110] Loss D: 0.6681, loss G: 3.9533


Epoch [6/110]: 100%|██████████| 1875/1875 [00:56<00:00, 33.15it/s, D_loss=0.599, G_loss=4.150]


Epoch [6/110] Loss D: 0.6699, loss G: 3.6988


Epoch [7/110]:  97%|█████████▋| 1815/1875 [00:53<00:01, 33.64it/s, D_loss=0.724, G_loss=3.214]


KeyboardInterrupt: 

In [None]:
data_iter = iter(dataloader)
a,n= next(data_iter)

In [None]:
to_gaus(a)