In [17]:
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
import torchvision
from torchvision import transforms
import numpy as np
from tqdm import tqdm
from syndata import get_exo_locations
import pandas as pd
import glob
from PIL import Image
import os

In [2]:
class SynDataset(Dataset):
    """Dataset that loads image files as single-channel tensors.

    Every item is opened with PIL, converted to 8-bit grayscale (mode ``'L'``)
    and passed through ``transforms.ToTensor()``.
    """

    def __init__(self, image_paths):
        """
        Args:
            image_paths: sequence of image file paths, indexed by position.
        """
        self.image_paths = image_paths
        self.transform  = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self,):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    @staticmethod
    def _replace_nans(image):
        """Return ``image`` with NaNs replaced via ``torch.nan_to_num``.

        ``torch.nan_to_num`` is not in-place, so its result has to be kept;
        tensors without any NaN are returned untouched.
        """
        if torch.isnan(image).any().item():
            image = torch.nan_to_num(image)
        return image

    def __getitem__(self, index):
        """Load the image at ``index`` and return it as a NaN-free tensor."""
        image_path = self.image_paths[index]

        # The context manager closes the underlying file as soon as the
        # grayscale copy has been made, instead of leaking the handle.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('L')

        image = self.transform(image)
        return self._replace_nans(image)
    

In [3]:
# glob.glob returns matches in arbitrary (filesystem) order; sorting makes the
# dataset index -> file mapping reproducible across runs and machines.
image_paths = sorted(glob.glob('/data/scratch/bariskurtkaya/dataset/PSF_INJECTION/*.png'))

In [4]:
# Wrap the collected image paths in the dataset, then serve it in batches of 16.
syndata = SynDataset(image_paths)
syndata_loader = DataLoader(syndata, batch_size=16)

In [5]:
# Pull the first batch out of the loader; idx is that batch's position (0).
data = next(iter(syndata_loader))
idx = 0

In [7]:
def calculate_conv_dims(input_size,paddings:list,kernels:list,strides:list,maxpool:list):
    """Trace the spatial size of a feature map through a stack of conv layers.

    One layer is processed per entry of ``paddings``. Each layer applies the
    convolution size formula and, when ``maxpool[i]`` is non-zero, a pooling
    step with window ``maxpool[i]``, a fixed stride of 2 and the same padding
    value as the convolution.

    Args:
        input_size: size of the input along one spatial axis.
        paddings: per-layer padding.
        kernels: per-layer kernel size.
        strides: per-layer stride.
        maxpool: per-layer pooling window; 0 means no pooling for that layer.

    Returns:
        A list starting with ``input_size`` followed by one truncated size per
        layer. The list is also printed.
    """
    sizes = [input_size]
    current = input_size

    for layer, pad in enumerate(paddings):
        # Convolution output size for this layer.
        current = (current + (2*pad) - (kernels[layer] - 1) - 1)/strides[layer] + 1

        pool = maxpool[layer]
        if pool != 0:
            # Optional pooling step (stride fixed to 2).
            current = (current + (2*pad) - (pool-1)-1)/2 +1

        # Only the recorded value is truncated; the running size stays a float.
        sizes.append(int(current))

    print(sizes)
    return sizes

In [8]:
# Encoder hyper-parameters, one entry per conv layer (7 layers).
kernels_enc = [7] * 7
paddings_enc = [0] * 7
strides_enc = [1, 2] * 3 + [1]   # stride-2 layers alternate with stride-1 layers
maxpool = [0] * 7                # 0 -> no pooling after that layer

In [9]:
# Spatial size after each encoder layer for a 320-pixel input side.
convdim_outputs = calculate_conv_dims(
    input_size=320,
    paddings=paddings_enc,
    kernels=kernels_enc,
    strides=strides_enc,
    maxpool=maxpool,
)

[320, 314, 154, 148, 71, 65, 30, 24]


In [10]:
def calculate_convtrans_dim(input_size,paddings:list,kernels:list,strides:list):
    """Trace the spatial size of a feature map through transposed-conv layers.

    One layer is processed per entry of ``paddings`` using
    ``(size - 1) * stride - 2 * padding + kernel``.

    Args:
        input_size: size of the input along one spatial axis.
        paddings: per-layer padding.
        kernels: per-layer kernel size.
        strides: per-layer stride.

    Returns:
        A list starting with ``input_size`` followed by one truncated size per
        layer. The list is also printed.
    """
    sizes = [input_size]
    current = input_size

    for layer, pad in enumerate(paddings):
        current = (current - 1) * strides[layer]  -  2 * pad + kernels[layer] - 1 + 1
        sizes.append(int(current))

    print(sizes)
    return sizes

In [11]:
# Decoder hyper-parameters, one entry per layer (7 layers); the last kernel is
# larger (10) than the others (7).
kernels_dec = [7] * 6 + [10]
paddings_dec = [0] * 7
strides_dec = [1, 2] * 3 + [1]

In [12]:
# Spatial size after each decoder layer, starting from a 24-pixel bottleneck side.
convtrans_outputs = calculate_convtrans_dim(
    input_size=24,
    paddings=paddings_dec,
    kernels=kernels_dec,
    strides=strides_dec,
)

[24, 30, 65, 71, 147, 153, 311, 320]


In [14]:
class Exonet(nn.Module):
    """Convolutional autoencoder for single-channel square images.

    Pipeline: 7 ``Conv2d`` + SiLU layers -> flatten -> fully connected layers
    down to a 1024-d latent -> fully connected layers back up -> reshape ->
    7 ``ConvTranspose2d`` + SiLU layers.

    Args:
        convdim_enc_outputs: spatial sizes after each encoder layer; only the
            last entry (the bottleneck side length) is used to size the fully
            connected layers and the reshape before the decoder.
        kernels_enc: kernel size of each of the 7 encoder layers.
        strides_enc: stride of each of the 7 encoder layers.
        kernels_dec: kernel size of each of the 7 decoder layers.
        strides_dec: stride of each of the 7 decoder layers.
    """

    def __init__(self, convdim_enc_outputs:list, kernels_enc:list, strides_enc:list, kernels_dec:list, strides_dec:list):

        super(Exonet,self).__init__()

        self.convdim = convdim_enc_outputs
        self.kernels_enc = kernels_enc
        self.strides_enc = strides_enc
        self.kernels_dec = kernels_dec
        self.strides_dec = strides_dec
        self.C       = 8

        # Number of features in the flattened bottleneck. Taken from the
        # constructor argument rather than a notebook-level global so the
        # module does not depend on hidden kernel state.
        bottleneck_features = (self.C*2) * self.convdim[-1]**2

        # Channel count before/after each encoder layer: 1 -> C x3 -> 2C x4.
        enc_channels = [1, self.C, self.C, self.C, self.C*2, self.C*2, self.C*2, self.C*2]
        encoder_layers = []
        for i in range(7):
            encoder_layers.append(
                nn.Conv2d(in_channels=enc_channels[i], out_channels=enc_channels[i+1],
                          stride=self.strides_enc[i], kernel_size=self.kernels_enc[i])
            )
            encoder_layers.append(nn.SiLU())
        self.encoder = nn.Sequential(*encoder_layers)

        self.fc1 = nn.Sequential(
            nn.Linear(bottleneck_features,2048),
            nn.SiLU(),
            nn.Linear(2048,1024),
            nn.SiLU(),
        )

        self.latent = nn.Linear(1024,1024)

        self.fc2   = nn.Sequential(
            nn.Linear(1024,2048),
            nn.SiLU(),
            nn.Linear(2048,bottleneck_features),
            nn.SiLU(),
        )

        # Channel count before/after each decoder layer: 2C x4 -> C x3 -> 1.
        dec_channels = [self.C*2, self.C*2, self.C*2, self.C*2, self.C, self.C, self.C, 1]
        decoder_layers = []
        for i in range(7):
            # Transposed convolutions grow the bottleneck back towards the
            # input resolution; a plain Conv2d here would shrink the map
            # further and fail once it is smaller than the kernel.
            decoder_layers.append(
                nn.ConvTranspose2d(in_channels=dec_channels[i], out_channels=dec_channels[i+1],
                                   stride=self.strides_dec[i], kernel_size=self.kernels_dec[i])
            )
            decoder_layers.append(nn.SiLU())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self,x):
        """Encode ``x``, pass it through the latent layer and decode it.

        Args:
            x: batch of images with one channel, i.e. ``(batch, 1, H, W)``
               as required by the first ``Conv2d`` layer.

        Returns:
            The decoder output; its spatial size is determined by the decoder
            kernels and strides.
        """
        bs       = x.size(0)

        x       = self.encoder(x)
        x       = x.view(bs,-1)

        x       = self.fc1(x)
        latents = self.latent(x)
        x       = self.fc2(latents)

        # Back to (batch, 2C, side, side) for the convolutional decoder.
        x       = x.view(bs,self.C*2,self.convdim[-1],self.convdim[-1])
        x       = self.decoder(x)

        return x
        

In [None]:
def conv_train(model,train_dataloader,optimizer,device,loss_fn,EPOCH=30,save_path=None):
    """Train ``model`` to reconstruct its own input and plot progress per epoch.

    For every batch the model output is compared with the batch itself using
    ``loss_fn``. After each epoch the mean batch loss is printed and the last
    batch is plotted next to its reconstruction through ``plot_results``
    (which must be defined in the notebook before this function is called).

    Args:
        model: module to train; it is moved to ``device``.
        train_dataloader: sized iterable yielding input batches (tensors).
        optimizer: optimizer already bound to ``model``'s parameters.
        device: device the model and the batches are moved to.
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.
        EPOCH: number of passes over ``train_dataloader``.
        save_path: directory for the per-epoch figures; defaults to the
            project's ``training_results`` folder. Created if missing.

    Returns:
        List with the mean loss of every epoch.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches.
    """
    if save_path is None:
        save_path = os.path.join('/home/sarperyn/sarperyurtseven/ProjectFiles', 'training_results')
    os.makedirs(save_path, exist_ok=True)

    epoch_losses = []

    with tqdm(total = len(train_dataloader) * EPOCH) as tt:

        model = model.to(device)
        model.train()

        for epoch in range(EPOCH):

            batch_loss, batch_count = 0.0, 0

            for idx, batch in enumerate(train_dataloader):

                batch = batch.to(device)

                output = model(batch)

                loss = loss_fn(output,batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Accumulate into the per-epoch running sum.
                batch_loss += loss.item()
                batch_count += 1
                tt.update()

            if batch_count == 0:
                # Nothing was trained on; avoid a ZeroDivisionError below and
                # a plot call with undefined tensors.
                raise ValueError('train_dataloader yielded no batches')

            batch_loss = batch_loss / batch_count
            epoch_losses.append(batch_loss)
            print(f'{batch_loss}')

            # Visualise the last batch of the epoch against its reconstruction.
            plot_results(batch, output, save_path, epoch, idx)

    return epoch_losses

In [15]:
def plot_results(imgs, recons, save_path, epoch, idx, current_deg=None):
    """Plot a batch of images above their reconstructions and save the figure.

    The figure has two rows (originals on top, reconstructions below) and one
    column per batch element. It is written to
    ``<save_path>/fig_<epoch>_<idx>.png``, shown, and then closed.

    Args:
        imgs: batch of input images, channel-first (each item is transposed
            with ``(1, 2, 0)`` before plotting).
        recons: batch of reconstructions with the same layout as ``imgs``.
        save_path: directory the PNG is written to; created if missing.
        epoch: epoch number used in the file name.
        idx: batch index used in the file name.
        current_deg: optional per-image values shown as
            ``Current Degree: <int>`` under each original image.
    """
    bs = imgs.size(0)
    # squeeze=False keeps ``axes`` two-dimensional even for a batch of one.
    fig, axes = plt.subplots(nrows=2,ncols=bs,figsize=(bs*4,20),squeeze=False)

    rows = ((imgs, 'Original Image'), (recons, 'Reconstructed Image'))

    for row, (images, row_label) in enumerate(rows):
        for col in range(bs):
            ax = axes[row][col]
            ax.imshow(np.transpose(images[col].detach().cpu().numpy(),(1,2,0)))

            if row == 0 and current_deg is not None:
                ax.set_xlabel(f'Current Degree: {int(current_deg[col])}',fontsize=15,fontweight='bold')

            if col == 0:
                ax.set_ylabel(row_label,fontsize=15,fontweight='bold')

            ax.set_yticks([])
            ax.set_xticks([])

    fig.subplots_adjust(wspace=0,hspace=0)

    os.makedirs(save_path, exist_ok=True)
    fig.savefig(os.path.join(save_path,f'fig_{epoch}_{idx}.png'),format='png',bbox_inches='tight',pad_inches=0,dpi=100)
    plt.show()
    # Called once per epoch from the training loop: release the figure so
    # open figures do not pile up in memory.
    plt.close(fig)


In [None]:
# def init_wandb(args):

#     wandb.init(
#     mode =   'online'if args.wandb else 'disabled',
#     project = 'exoplanet_detection',
#     entity  = '',
#     group   = 'fermat-direction-gen',
#     name    = f'{args.device}-fermat',
#     config = vars(args)
# )