In [1]:
import os
import pandas as pd
import random
from collections import OrderedDict
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from torch.nn.utils import spectral_norm
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from torchvision.utils import make_grid
import warnings
warnings.filterwarnings("ignore")
#from torchsummaryX import summary

In [2]:
class ccbn(nn.Module):
    """Class-conditional batch normalization.

    The input feature map is batch-normalized without affine parameters, then
    scaled and shifted by a per-sample gain and bias predicted from the
    conditioning vector ``y`` through two spectrally-normalized, bias-free
    linear layers.

    Args:
        input_size: Size of the conditioning vector ``y`` (last dimension).
        output_size: Number of channels of the feature map being normalized.
        eps: Value added to the variance inside batch norm to avoid dividing by 0.
        momentum: Momentum used to update the running mean / variance buffers.
    """

    def __init__(self, input_size, output_size, eps=1e-4, momentum=0.1):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Prepare gain and bias layers
        self.gain = spectral_norm(nn.Linear(input_size, output_size, bias = False), eps = 1e-4)
        self.bias = spectral_norm(nn.Linear(input_size, output_size, bias = False), eps = 1e-4)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum of the running-statistics update
        self.momentum = momentum

        # Running statistics: updated in training mode, used as-is in eval mode.
        self.register_buffer('stored_mean', torch.zeros(output_size))
        self.register_buffer('stored_var',  torch.ones(output_size))

    def forward(self, x, y):
        """Normalize ``x`` and modulate it with the gain/bias predicted from ``y``.

        Args:
            x: Feature map; the gain/bias are reshaped to ``(B, C, 1, 1)`` so
                they broadcast over the two trailing dimensions of ``x``.
            y: Conditioning vectors, one row per sample, of width ``input_size``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Calculate class-conditional gains and biases (gain is centred on 1).
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # Use the configured momentum; it was previously hard-coded to 0.1,
        # which silently ignored the constructor argument.
        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                          self.training, self.momentum, self.eps)
        return out * gain + bias

    def extra_repr(self):
        """Return the size summary shown when the module is printed."""
        s = 'out: {output_size}, in: {input_size},'
        return s.format(output_size=self.output_size, input_size=self.input_size)

In [3]:
class Self_Attn(nn.Module):
    """Self-attention over the spatial positions of a 2D feature map.

    Queries and keys come from 1x1 convolutions that reduce the channel count
    to ``in_dim // 8``; values keep ``in_dim`` channels. The attended features
    are blended back into the input through the learnable scalar ``gamma``,
    which starts at zero so the layer is initially an identity mapping.
    """

    def __init__(self,in_dim,activation = nn.ReLU(inplace = False)):
        super().__init__()
        self.chanel_in = in_dim
        # Stored for reference only; forward() does not apply it.
        self.activation = activation

        reduced_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        # Normalizes each query's scores over all key positions.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self,x):
        """Apply self-attention and add the result to the input.

        Args:
            x: Feature map of shape (B, C, W, H).

        Returns:
            Tensor of shape (B, C, W, H) equal to ``gamma * attended + x``.
        """
        batch, channels, width, height = x.size()
        n_positions = width * height

        # Flatten the spatial grid: each projection becomes (B, C', N).
        queries = self.query_conv(x).view(batch, -1, n_positions)
        keys = self.key_conv(x).view(batch, -1, n_positions)
        values = self.value_conv(x).view(batch, -1, n_positions)

        # (B, N, C') x (B, C', N) -> (B, N, N) pairwise position scores.
        energy = torch.bmm(queries.permute(0, 2, 1), keys)
        attention = self.softmax(energy)

        # Weighted sum of values for every query position -> (B, C, N).
        attended = torch.bmm(values, attention.permute(0, 2, 1))
        attended = attended.view(batch, channels, width, height)

        return self.gamma * attended + x

In [4]:
class GeneratorResBlock(nn.Module):
    """Bottleneck residual block of the generator with conditional batch norm.

    The residual branch is 1x1 conv -> 3x3 conv -> 3x3 conv -> 1x1 conv, each
    preceded by class-conditional batch norm and ReLU, with a hidden width of
    ``in_channels // 4``. The skip path has no learnable parameters: surplus
    channels are sliced off and the optional ``upsample`` module is applied.
    """

    def __init__(self, in_channels, out_channels, upsample = None, embed_dim = 128, dim_z = 128):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = self.in_channels // 4
        hidden = self.hidden_channels

        def sn_conv(c_in, c_out, kernel_size, padding):
            # Spectrally-normalized convolution shared by all four layers.
            return spectral_norm(
                nn.Conv2d(c_in, c_out, kernel_size = kernel_size, padding = padding),
                eps = 1e-4)

        self.conv1 = sn_conv(self.in_channels, hidden, 1, 0)
        self.conv2 = sn_conv(hidden, hidden, 3, 1)
        self.conv3 = sn_conv(hidden, hidden, 3, 1)
        self.conv4 = sn_conv(hidden, self.out_channels, 1, 0)

        # The conditioning vector is the noise concatenated with three embeddings.
        cond_dim = (3 * embed_dim) + dim_z
        self.bn1 = ccbn(input_size = cond_dim, output_size = self.in_channels)
        self.bn2 = ccbn(input_size = cond_dim, output_size = hidden)
        self.bn3 = ccbn(input_size = cond_dim, output_size = hidden)
        self.bn4 = ccbn(input_size = cond_dim, output_size = hidden)

        self.activation = nn.ReLU(inplace=False)

        self.upsample = upsample

    def forward(self,x,y):
        """Run the block on feature map ``x`` conditioned on vector ``y``."""
        relu = self.activation

        # BN-ReLU, then project down to the bottleneck width.
        h = self.bn1(x, y)
        h = self.conv1(relu(h))
        # Second BN-ReLU happens before any upsampling.
        h = relu(self.bn2(h, y))

        # Skip path: keep only the first out_channels channels when narrowing.
        skip = x
        if self.in_channels != self.out_channels:
            skip = skip[:, :self.out_channels]

        # Both branches are upsampled at the same point.
        if self.upsample:
            h = self.upsample(h)
            skip = self.upsample(skip)

        # Two 3x3 convolutions at bottleneck width.
        h = self.conv2(h)
        h = self.bn3(h, y)
        h = self.conv3(relu(h))

        # Final 1x1 convolution back up to out_channels.
        h = self.bn4(h, y)
        h = self.conv4(relu(h))
        return h + skip

In [5]:
class Generator(nn.Module):
    """Conditional generator built from bottleneck residual blocks.

    Three categorical labels (temperature, time, cooling method) are embedded
    and concatenated with the noise vector. The concatenated vector feeds the
    first linear layer and conditions every residual block. Six of the blocks
    upsample by 2, so the output side length is ``bottom_width * 64``
    (256 with the default ``bottom_width=4``).

    Args:
        G_ch: Base channel multiplier.
        dim_z: Size of the noise vector.
        bottom_width: Side length of the first feature map.
        img_channels: Number of channels of the generated image.
        init: Weight init style: 'ortho', 'N02', 'glorot' or 'xavier'.
        n_classes_temp: Number of temperature classes.
        n_classes_time: Number of time classes.
        n_classes_cool: Number of cooling-method classes.
        embed_dim: Size of each label embedding.
    """

    def __init__(self, G_ch = 64, dim_z=128, bottom_width=4, img_channels = 1,
                 init = 'ortho',n_classes_temp = 7, n_classes_time = 8, n_classes_cool = 4, embed_dim = 128):
        super(Generator, self).__init__()
        self.ch = G_ch
        self.dim_z = dim_z
        self.bottom_width = bottom_width
        self.init = init
        self.img_channels = img_channels

        self.embed_temp = nn.Embedding(n_classes_temp, embed_dim)
        self.embed_time = nn.Embedding(n_classes_time, embed_dim)
        self.embed_cool = nn.Embedding(n_classes_cool, embed_dim)

        self.linear = spectral_norm(nn.Linear(dim_z + (3 * embed_dim), 16 * self.ch * (self.bottom_width **2)), eps = 1e-4)

        def res_block(in_mult, out_mult, upsample = False):
            # Forward embed_dim / dim_z so the blocks' conditional batch norms
            # match the conditioning vector built in forward(). They previously
            # fell back to the block defaults (128 / 128), which broke any
            # generator created with non-default sizes.
            return GeneratorResBlock(
                in_mult * self.ch, out_mult * self.ch,
                upsample = nn.Upsample(scale_factor = 2) if upsample else None,
                embed_dim = embed_dim, dim_z = dim_z)

        self.blocks = nn.ModuleList([
                res_block(16, 16),
                res_block(16, 16, upsample = True),
                res_block(16, 16),
                res_block(16, 8, upsample = True),
                res_block(8, 8),
                res_block(8, 8, upsample = True),
                res_block(8, 8),
                res_block(8, 4, upsample = True),
                Self_Attn(4*self.ch),
                res_block(4, 4),
                res_block(4, 2, upsample = True),
                res_block(2, 2),
                res_block(2, 1, upsample = True)
        ])

        self.final_layer = nn.Sequential(
                nn.BatchNorm2d(self.ch),
                nn.ReLU(inplace = False),
                spectral_norm(nn.Conv2d(self.ch, self.img_channels, kernel_size = 3, padding = 1)),
                nn.Tanh()
        )

        self.init_weights()

    def init_weights(self):
        """Initialize conv / linear / embedding weights and count their parameters.

        Runs before any forward pass. An unrecognized ``init`` leaves the
        weights at their PyTorch defaults and prints a notice per module.
        """
        print(f"Weight initialization : {self.init}")
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                if self.init == 'ortho':
                    torch.nn.init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    torch.nn.init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    torch.nn.init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        print("Param count for G's initialized parameters: %d Million" % (self.param_count/1000000))

    def forward(self,z , y_temp, y_time, y_cool):
        """Generate images from noise and the three class labels.

        Args:
            z: Noise of shape (B, dim_z).
            y_temp, y_time, y_cool: Integer class-index tensors of shape (B,).

        Returns:
            Tensor of shape (B, img_channels, H, W) with values in [-1, 1] (tanh output).
        """
        y_temp = self.embed_temp(y_temp)
        y_time = self.embed_time(y_time)
        y_cool = self.embed_cool(y_cool)
        # Conditioning vector shared by the first linear layer and all blocks.
        z = torch.cat([z, y_temp, y_time, y_cool], 1)
        # First linear layer
        h = self.linear(z)
        # Reshape into the bottom feature map
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
        # Loop over blocks; the attention layer is unconditional. It is
        # identified by type rather than by a hard-coded list position.
        for block in self.blocks:
            if isinstance(block, Self_Attn):
                h = block(h)
            else:
                h = block(h, z)
        # Apply batchnorm-relu-conv-tanh at output
        h = self.final_layer(h)
        return h

In [6]:
class DiscriminatorResBlock(nn.Module):
    """Bottleneck residual block of the discriminator.

    The residual branch is ReLU -> 1x1 conv -> ReLU -> 3x3 conv -> ReLU ->
    3x3 conv -> ReLU -> optional downsample -> 1x1 conv, with a hidden width
    of ``out_channels // channel_ratio``. When the channel count grows, the
    shortcut concatenates the input with a learned 1x1 projection that
    supplies the extra ``out_channels - in_channels`` channels.

    ``preactivation`` is stored on the module but not consulted by ``forward``.
    """

    def __init__(self, in_channels, out_channels, preactivation=True, 
                 downsample=None,channel_ratio=4):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.hidden_channels = self.out_channels // channel_ratio
        self.preactivation = preactivation
        self.activation = nn.ReLU(inplace=False)
        self.downsample = downsample
        hidden = self.hidden_channels

        def sn_conv(c_in, c_out, kernel_size, padding):
            # Spectrally-normalized convolution used for every layer of the block.
            return spectral_norm(
                nn.Conv2d(c_in, c_out, kernel_size=kernel_size, padding=padding),
                eps = 1e-4)

        # Residual branch
        self.conv1 = sn_conv(self.in_channels, hidden, 1, 0)
        self.conv2 = sn_conv(hidden, hidden, 3, 1)
        self.conv3 = sn_conv(hidden, hidden, 3, 1)
        self.conv4 = sn_conv(hidden, self.out_channels, 1, 0)

        # Learned shortcut only exists when the channel count changes.
        self.learnable_sc = in_channels != out_channels
        if self.learnable_sc:
            self.conv_sc = sn_conv(in_channels, out_channels - in_channels, 1, 0)

    def shortcut(self, x):
        """Return the skip-path tensor matching the residual branch's shape."""
        skip = self.downsample(x) if self.downsample else x
        if not self.learnable_sc:
            return skip
        # Keep the original channels and append the learned extra ones.
        return torch.cat([skip, self.conv_sc(skip)], 1)

    def forward(self, x):
        """Apply the residual branch and add the shortcut."""
        # 1x1 bottleneck conv on the activated input
        residual = self.conv1(F.relu(x))
        # Two ReLU -> 3x3 conv stages
        for conv in (self.conv2, self.conv3):
            residual = conv(self.activation(residual))
        # ReLU is applied before the (optional) downsample
        residual = self.activation(residual)
        if self.downsample:
            residual = self.downsample(residual)
        # Final 1x1 conv restores out_channels
        residual = self.conv4(residual)
        return residual + self.shortcut(x)

In [7]:
class Discriminator(nn.Module):
    """Projection discriminator conditioned on three categorical labels.

    An input convolution is followed by a stack of bottleneck residual blocks
    (six of them downsample by 2) and one self-attention layer. The final
    features are ReLU-ed and sum-pooled over space, then scored by a linear
    layer plus one projection term per label embedding.
    """

    def __init__(self, D_ch = 64, img_channels = 1, output_dim = 1,
                 init = 'ortho',n_classes_temp = 7, n_classes_time = 8, n_classes_cool = 4):
        super().__init__()
        self.ch = D_ch
        self.init = init
        self.img_channels = img_channels
        self.output_dim = output_dim
        ch = self.ch

        # Stem convolution
        self.input_conv = spectral_norm(
            nn.Conv2d(self.img_channels, ch, kernel_size = 3, padding = 1), eps = 1e-4)

        def pool():
            # Fresh 2x average-pooling module for each downsampling block.
            return nn.AvgPool2d(2)

        stages = [
            DiscriminatorResBlock(ch, 2*ch, downsample = pool()),
            DiscriminatorResBlock(2*ch, 2*ch),
            DiscriminatorResBlock(2*ch, 4*ch, downsample = pool()),
            DiscriminatorResBlock(4*ch, 4*ch),
            Self_Attn(4*ch),
            DiscriminatorResBlock(4*ch, 8*ch, downsample = pool()),
            DiscriminatorResBlock(8*ch, 8*ch),
            DiscriminatorResBlock(8*ch, 8*ch, downsample = pool()),
            DiscriminatorResBlock(8*ch, 8*ch),
            DiscriminatorResBlock(8*ch, 16*ch, downsample = pool()),
            DiscriminatorResBlock(16*ch, 16*ch),
            DiscriminatorResBlock(16*ch, 16*ch, downsample = pool()),
            DiscriminatorResBlock(16*ch, 16*ch),
        ]
        self.blocks = nn.Sequential(*stages)

        # Unconditional output head; output_dim is normally 1.
        self.linear = spectral_norm(nn.Linear(16*ch, output_dim), eps = 1e-4)
        # One embedding per label, each as wide as the pooled feature vector.
        self.embed_temp = nn.Embedding(n_classes_temp, 16*ch)
        self.embed_time = nn.Embedding(n_classes_time, 16*ch)
        self.embed_cool = nn.Embedding(n_classes_cool, 16*ch)

        self.init_weights()

    def init_weights(self):
        """Initialize conv / linear / embedding weights and tally their parameters."""
        print(f"Weight initialization : {self.init}")
        self.param_count = 0
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear, nn.Embedding)):
                continue
            if self.init == 'ortho':
                torch.nn.init.orthogonal_(layer.weight)
            elif self.init == 'N02':
                torch.nn.init.normal_(layer.weight, 0, 0.02)
            elif self.init in ['glorot', 'xavier']:
                torch.nn.init.xavier_uniform_(layer.weight)
            else:
                print('Init style not recognized...')
            self.param_count += sum([p.data.nelement() for p in layer.parameters()])
        print("Param count for D's initialized parameters: %d Million" % (self.param_count/1000000))

    def forward(self, x, y_temp, y_time, y_cool):
        """Score images ``x`` given their three class-index tensors.

        Returns a tensor of shape (B, output_dim) (the projection terms
        broadcast over the last dimension).
        """
        features = self.blocks(self.input_conv(x))
        # Global sum pooling over the spatial dimensions after a ReLU.
        pooled = torch.sum(F.relu(features), [2, 3])
        # Class-unconditional evidence.
        out = self.linear(pooled)
        # Add the projection of the pooled features onto each label embedding.
        label_heads = (
            (self.embed_temp, y_temp),
            (self.embed_time, y_time),
            (self.embed_cool, y_cool),
        )
        for embed, labels in label_heads:
            out = out + torch.sum(embed(labels) * pooled, 1, keepdim=True)
        return out

In [8]:
class MicrographDataset(Dataset):
    """Dataset of micrograph images with their three processing labels.

    Each item is ``(image, anneal_temp, anneal_time, cooling_type)`` where the
    image is loaded as single-channel ('L') and the three labels are integer
    class indices looked up in the tables below.

    Attributes
    ----------
    df : pandas.core.frame.DataFrame
        Metadata with the columns ``path``, ``anneal_temperature``,
        ``anneal_time``, ``anneal_time_unit`` and ``cool_method``.
    root_dir : str
        Folder the image paths are resolved against.
    transform : callable or None
        Optional transform applied to the loaded PIL image.
    """

    # Raw label value -> class index. Unknown values raise KeyError.
    TEMP_CLASSES = {970: 0, 800: 1, 900: 2, 1100: 3, 1000: 4, 700: 5, 750: 6}
    # Keys are durations in minutes (hour values are converted before lookup).
    TIME_CLASSES = {90: 0, 1440: 1, 180: 2, 5: 3, 480: 4, 5100: 5, 60: 6, 2880: 7}
    COOLING_CLASSES = {'Q': 0, 'FC': 1, 'AR': 2, '650-1H': 3}

    def __init__(self, df, root_dir, transform=None):
        self.df = df
        self.transform = transform
        self.root_dir = root_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``idx``-th sample (positional, 0 <= idx < len(self)).

        Raises:
            KeyError: If a label value is missing from the class tables.
        """
        # Positional lookup: the sampler hands out 0..len-1, which only matches
        # label-based .loc when the frame has a default RangeIndex.
        row = self.df.iloc[idx]
        img_path = self.root_dir + '/' + 'Cropped' + row['path']

        anneal_temp = self.TEMP_CLASSES[row['anneal_temperature']]

        # Normalize the duration to minutes before the class lookup.
        anneal_minutes = row['anneal_time']
        if row['anneal_time_unit'] == 'H':
            anneal_minutes = int(anneal_minutes) * 60
        anneal_time = self.TIME_CLASSES[anneal_minutes]

        cooling_type = self.COOLING_CLASSES[row['cool_method']]

        # convert() returns a new image, so the file can be closed right away.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('L')
        if self.transform is not None:
            img = self.transform(img)
        return img , anneal_temp, anneal_time, cooling_type

In [9]:
class MicrographBigGAN(pl.LightningModule):
    """Lightning module that trains the conditional GAN with a hinge loss.

    Args:
        root_dir: Folder containing the training images.
        df_dir: Path of the Excel metadata file read by ``train_dataloader``.
        batch_size: Batch size of the training DataLoader.
        lr: Learning rate shared by both Adam optimizers.
    """

    def __init__(self, root_dir, df_dir, batch_size, lr = 2.5e-5):
        super().__init__()
        self.save_hyperparameters()
        self.root_dir = root_dir
        self.df_dir = df_dir
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.batch_size = batch_size 
        self.lr = lr
        # Latest generated batch (detached), kept only for image logging.
        self.fake = None
        
    def forward(self, z, y_temp, y_time, y_cool):
        """Generate images from noise ``z`` and the three class-index tensors."""
        return self.generator(z, y_temp, y_time, y_cool)
    
    def discriminator_loss(self, disc_real, disc_fake):
        """Hinge loss for the discriminator: mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))."""
        loss_real = torch.mean(F.relu(1. - disc_real))
        loss_fake = torch.mean(F.relu(1. + disc_fake))
        return loss_real + loss_fake
    
    def generator_loss(self, disc_fake):
        """Hinge loss for the generator: -mean(D(fake))."""
        loss = -torch.mean(disc_fake)
        return loss
    
    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one optimization step.

        ``optimizer_idx`` follows the order returned by ``configure_optimizers``:
        0 is the discriminator, 1 is the generator.
        """
        real, y_temp, y_time, y_cool = batch
        # Sample noise with the generator's own latent size, not a literal 128.
        z = torch.randn(real.shape[0], self.generator.dim_z)
        z = z.type_as(real)
        
        if optimizer_idx == 0:
            # Detach so the discriminator update does not back-propagate
            # through (or keep the graph of) the generator.
            fake = self(z, y_temp, y_time, y_cool).detach()
            self.fake = fake
            # These are real samples, so log them under their own tag instead
            # of 'generated_images' (which on_epoch_end uses for fakes).
            grid = make_grid(real[:4])
            self.logger.experiment.add_image('real_images', grid, 0)
        
            disc_real = self.discriminator(real, y_temp, y_time, y_cool)
            disc_fake = self.discriminator(fake, y_temp, y_time, y_cool)
            d_loss = self.discriminator_loss(disc_real, disc_fake)
            tqdm_dict = {'d_loss': d_loss}
            output = OrderedDict({
                'loss': d_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output
        
        if optimizer_idx == 1:
            fake = self(z,y_temp, y_time, y_cool)
            # Keep only a detached copy for logging; the loss uses `fake` itself.
            self.fake = fake.detach()
            g_loss = self.generator_loss(self.discriminator(fake, y_temp, y_time, y_cool))
            tqdm_dict = {'g_loss': g_loss}
            output = OrderedDict({
                'loss': g_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output
        
    def configure_optimizers(self):
        """Return the discriminator and generator Adam optimizers.

        The ``frequency`` entries are Lightning's per-optimizer step counts:
        2 for the discriminator, 1 for the generator.
        """
        opt_g = optim.Adam(self.generator.parameters(), lr = self.lr)
        opt_d = optim.Adam(self.discriminator.parameters(), lr = self.lr)
        return (
            {'optimizer': opt_d, 'frequency': 2},
            {'optimizer': opt_g, 'frequency': 1}
        )
    
    def train_dataloader(self):
        """Build the shuffled training DataLoader of 256x256 images scaled to [-1, 1]."""
        img_transforms = transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.ToTensor(),
            transforms.Normalize([0.5 for _ in range(1)],[0.5 for _ in range(1)]),
        ])
        df = pd.read_excel(self.df_dir)
        dataset = MicrographDataset(df,self.root_dir,transform = img_transforms)
        return DataLoader(dataset, batch_size=self.batch_size,shuffle=True)
    
    def on_epoch_end(self):
        """Log up to four of the most recently generated images for this epoch."""
        # Nothing to log if the hook fires before any training step has run.
        if self.fake is None:
            return
        grid = make_grid(self.fake[:4])
        self.logger.experiment.add_image('generated_images', grid, self.current_epoch)

In [10]:
# Image folder, passed to MicrographBigGAN as root_dir.
# NOTE(review): MicrographDataset appends '/Cropped' + row['path'] to this value,
# and this path already ends in 'Cropped' -- confirm the resulting file paths exist.
ROOT_DIR = '../input/highcarbon-micrographs/For Training/Cropped'
# Excel metadata sheet, passed as df_dir and read with pd.read_excel.
DF_DIR = '../input/highcarbon-micrographs/new_metadata.xlsx'

In [11]:
# Build the Lightning module; this constructs (and initializes) both networks.
gan = MicrographBigGAN(ROOT_DIR,DF_DIR,batch_size=8)

Weight initialization : ortho
Param count for G's initialized parameters: 29 Million
Weight initialization : ortho
Param count for D's initialized parameters: 9 Million


In [12]:
# Train for at most 100 epochs, on one GPU when CUDA is available and on CPU otherwise.
trainer = pl.Trainer(max_epochs=100, gpus=1 if torch.cuda.is_available() else 0)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [14]:
from urllib.request import urlopen
from io import BytesIO
from zipfile import ZipFile
from subprocess import Popen
from os import chmod
from os.path import isfile
import json
import time
import psutil

def launch_tensorboard():
    """Start TensorBoard and an ngrok tunnel to it, then print the public URL.

    TensorBoard serves ``./lightning_logs`` on port 6006. The ngrok binary is
    downloaded into the working directory if it is not already there.

    NOTE(review): this downloads and executes a binary and exposes the
    TensorBoard port publicly; only use it in a disposable environment.

    Returns:
        ``(tb_process, ngrok_process)``: the ``Popen`` handles of the processes
        started by this call; an entry is ``None`` when that process was
        already running.

    Raises:
        urllib.error.URLError: If the local ngrok API never becomes reachable.
        AssertionError: If ngrok is reachable but reports no tunnel.
    """
    # Local import: only needed for the readiness polling below.
    from urllib.error import URLError

    tb_process, ngrok_process = None, None
    
    # Launch TensorBoard
    if not is_process_running('tensorboard'):
        tb_command = 'tensorboard --logdir ./lightning_logs --host 0.0.0.0 --port 6006'
        tb_process = run_cmd_async_unsafe(tb_command)
    
    # Install ngrok
    if not isfile('./ngrok'):
        ngrok_url = 'https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip'
        download_and_unzip(ngrok_url)
        chmod('./ngrok', 0o755)

    # Create ngrok tunnel
    if not is_process_running('ngrok'):
        ngrok_process = run_cmd_async_unsafe('./ngrok http 6006')

    # Poll the local ngrok API until a tunnel is listed. A fixed one-second
    # sleep raced with ngrok's start-up and could fail spuriously.
    deadline = time.monotonic() + 15
    while True:
        try:
            # `with` closes the HTTP response once it has been parsed.
            with urlopen('http://127.0.0.1:4040/api/tunnels', timeout=10) as ngrok_api_res:
                tunnels = json.load(ngrok_api_res)['tunnels']
        except URLError:
            if time.monotonic() >= deadline:
                raise
            tunnels = []
        if tunnels or time.monotonic() >= deadline:
            break
        time.sleep(0.5)

    assert len(tunnels) > 0, 'ngrok tunnel not found'
    tb_public_url = tunnels[0]['public_url']
    print(f'TensorBoard URL: {tb_public_url}')

    return tb_process, ngrok_process


def download_and_unzip(url, extract_to='.'):
    """Download the zip archive at ``url`` and extract it into ``extract_to``.

    The whole archive is buffered in memory before extraction.

    Args:
        url: Location of the zip archive (any scheme ``urlopen`` accepts).
        extract_to: Destination directory; defaults to the working directory.
    """
    # Close the HTTP response and the archive deterministically instead of
    # leaving them to the garbage collector.
    with urlopen(url) as http_response:
        payload = http_response.read()
    with ZipFile(BytesIO(payload)) as archive:
        archive.extractall(path=extract_to)


def run_cmd_async_unsafe(cmd):
    """Spawn ``cmd`` through the shell without waiting for it to finish.

    "Unsafe" because the string is handed to the shell verbatim; never pass
    untrusted input. Returns the ``Popen`` handle of the started process.
    """
    process = Popen(cmd, shell=True)
    return process


def is_process_running(process_name):
    """Return True if any running process is named exactly ``process_name``.

    Processes that exit or cannot be inspected while the list is scanned are
    skipped rather than aborting the check.
    """
    for proc in psutil.process_iter():
        try:
            if proc.name() == process_name:
                return True
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # The process vanished or is not ours to inspect; skip it.
            continue
    return False


# Start TensorBoard + ngrok; a handle is None when that process was already running.
tb_process, ngrok_process = launch_tensorboard()

TensorBoard URL: http://c184c32a5645.ngrok.io


In [13]:
# Run the Lightning training loop on the GAN module.
trainer.fit(gan)


  | Name          | Type          | Params
------------------------------------------------
0 | generator     | Generator     | 29.5 M
1 | discriminator | Discriminator | 9.1 M 


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1