In [1]:
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
import torchvision
from torchvision import transforms
import numpy as np
from tqdm import tqdm
from syndata import get_exo_locations
import pandas as pd
from get_syndata import *
from models import Exocoder,vae_loss

In [2]:
class SyntheticDataset(Dataset):
    """Map-style dataset pooling synthetic exoplanet and star images.

    Exoplanet samples come first, star samples second.

    Parameters
    ----------
    star, exo
        When ``labelled`` is True, each is indexed as ``(images, labels)``:
        element ``[0]`` is joined with ``np.concatenate`` and element ``[1]``
        with ``torch.vstack``. When ``labelled`` is False, each is an
        array-like of images joined directly with ``np.concatenate``.
    transform : bool
        If truthy, every returned sample goes through ``ToTensor`` followed by
        ``CenterCrop(crop_size)``.
    labelled : bool
        Whether ``__getitem__`` returns ``[sample, label]`` or just ``sample``.
    crop_size : int, default 160
        Side length passed to ``transforms.CenterCrop`` (the value that was
        previously hard-coded).
    """

    def __init__(self, star, exo, transform=True, labelled=True, crop_size=160):
        self.star = star
        self.exo = exo
        self.transform = transform
        self.labelled = labelled
        self.crop_size = crop_size
        # Built lazily on first use and then reused; previously an identical
        # Compose pipeline was constructed on every __getitem__ call.
        self._pipeline = None

        if labelled:
            self.arr = np.concatenate((exo[0], star[0]))
            self.label = torch.vstack((exo[1], star[1]))
            self.data = [self.arr, self.label]
        else:
            self.data = np.concatenate((exo, star))

    def __len__(self):
        return len(self.arr) if self.labelled else len(self.data)

    def _apply_transform(self, sample):
        """Return ``sample`` passed through ToTensor + CenterCrop when ``self.transform`` is set."""
        if not self.transform:
            return sample
        if self._pipeline is None:
            self._pipeline = transforms.Compose([
                transforms.ToTensor(),
                transforms.CenterCrop(self.crop_size),
            ])
        return self._pipeline(sample)

    def __getitem__(self, idx):
        if self.labelled:
            sample = self._apply_transform(self.data[0][idx])
            label = self.data[1][idx]
            return [sample, label]
        return self._apply_transform(self.data[idx])
        

In [3]:
# Load the train/test star and exoplanet splits. `get_train_test` presumably
# comes from the star-imported `get_syndata` module (it is not defined in this
# notebook); the "Exo data / Star data" summary below is presumably printed by it.
train_star, train_exo, test_star, test_exo = get_train_test()

Exo data: (12300, 320, 320)
Star data: (8064, 320, 320)


In [4]:
# Unlabelled datasets: each item is a single (transformed) image, exo samples first then star.
train_dataset = SyntheticDataset(train_star, train_exo, labelled=False)
test_dataset = SyntheticDataset(test_star, test_exo, labelled=False)

In [5]:
# Only the training split needs shuffling; a fixed order keeps the test batches
# (including the preview batch taken from it) reproducible between runs.
BATCH_SIZE = 512
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
# Grab the first batch from each loader for a quick visual sanity check.
# (next(iter(...)) instead of next(enumerate(...)) avoids binding IPython's `_`.)
train_samples = next(iter(train_dataloader))
test_sample = next(iter(test_dataloader))

In [7]:
# Concatenate the first N_PREVIEW images of each batch along dim 0 for plotting.
# The previous loop seeded the stack with sample 0 twice and then started at
# index 2, so sample 1 was silently dropped and sample 0 duplicated.
N_PREVIEW = 25
train_stack = torch.concat(tuple(train_samples[:N_PREVIEW]), dim=0)
test_stack = torch.concat(tuple(test_sample[:N_PREVIEW]), dim=0)


In [8]:
#visualize_data(train_stack)

In [9]:
def calculate_conv_dims(input_size, paddings: list, kernels: list, strides: list, maxpool: list, verbose: bool = True):
    """Trace the spatial size of a square input through a stack of Conv2d layers.

    Each layer uses PyTorch's Conv2d output-size formula with dilation 1::

        out = (in + 2 * padding - kernel) // stride + 1

    When ``maxpool[i]`` is non-zero, the conv output is further reduced by a
    pooling step with kernel ``maxpool[i]``, stride 2 and padding
    ``paddings[i]`` (the conv layer's padding is reused for the pool).

    Parameters
    ----------
    input_size : int
        Height (== width) of the input.
    paddings, kernels, strides, maxpool : list of int
        Per-layer settings. ``len(paddings)`` sets the number of layers; the
        other lists must be at least that long. A ``maxpool`` entry of 0 means
        "no pooling after this layer".
    verbose : bool, default True
        Print the resulting list (the original behaviour).

    Returns
    -------
    list of int
        ``[input_size, size_after_layer_1, ..., size_after_layer_n]``.

    Raises
    ------
    ValueError
        If some layer would produce a non-positive size.
    """
    outputs = [input_size]
    size = input_size
    for i in range(len(paddings)):
        # Integer floor division mirrors PyTorch exactly and avoids carrying
        # fractional float sizes from one layer into the next.
        size = (size + 2 * paddings[i] - kernels[i]) // strides[i] + 1
        if maxpool[i] != 0:
            size = (size + 2 * paddings[i] - maxpool[i]) // 2 + 1
        if size <= 0:
            raise ValueError(f"layer {i} produces a non-positive output size ({size})")
        outputs.append(int(size))

    if verbose:
        print(outputs)
    return outputs

In [10]:
def calculate_convtrans_dim(input_size, paddings: list, kernels: list, strides: list, output_paddings: list = None, verbose: bool = True):
    """Trace the spatial size of a square input through a stack of ConvTranspose2d layers.

    Uses PyTorch's ConvTranspose2d output-size formula with dilation 1::

        out = (in - 1) * stride - 2 * padding + (kernel - 1) + output_padding + 1

    Parameters
    ----------
    input_size : int
        Height (== width) of the input feature map.
    paddings, kernels, strides : list of int
        Per-layer settings; ``len(paddings)`` sets the number of layers.
    output_paddings : list of int, optional
        Per-layer ``output_padding``; treated as all zeros when omitted
        (the original behaviour).
    verbose : bool, default True
        Print the resulting list (the original behaviour).

    Returns
    -------
    list of int
        ``[input_size, size_after_layer_1, ..., size_after_layer_n]``.
    """
    outputs = [input_size]
    size = input_size
    for i in range(len(paddings)):
        extra = 0 if output_paddings is None else output_paddings[i]
        size = (size - 1) * strides[i] - 2 * paddings[i] + (kernels[i] - 1) + extra + 1
        outputs.append(int(size))

    if verbose:
        print(outputs)
    return outputs

In [11]:
# Encoder layer settings: nine unpadded 3x3 convs, stride 2 at layers 4 and 8, no max-pooling.
kernels_forward = [3] * 9
paddings_forward = [0] * 9
strides_forward = [1, 1, 1, 2, 1, 1, 1, 2, 1]
maxpool = [0] * 9

In [12]:
# Decoder layer settings: seven unpadded transposed convs (the last one 2x2), stride 2 at layers 4 and 6.
kernels_backward = [3] * 6 + [2]
paddings_backward = [0] * 7
strides_backward = [1, 1, 1, 2, 1, 2, 1]

In [13]:
# Spatial size after each encoder layer, starting from the 160x160 crop.
convdim_outputs = calculate_conv_dims(
    input_size=160,
    paddings=paddings_forward,
    kernels=kernels_forward,
    strides=strides_forward,
    maxpool=maxpool,
)

[160, 158, 156, 154, 76, 74, 72, 70, 34, 32]


In [14]:
# Spatial size after each decoder layer, starting from the encoder's final 32x32 map.
convtrans_outputs = calculate_convtrans_dim(
    input_size=32,
    paddings=paddings_backward,
    kernels=kernels_backward,
    strides=strides_backward,
)

[32, 34, 36, 38, 77, 79, 159, 160]


In [40]:
class Exocoder(nn.Module):
    """Convolutional VAE-style autoencoder for single-channel square images.

    NOTE(review): this cell redefines ``Exocoder``, which is also imported from
    ``models`` in the first cell; whichever cell ran last determines the class.

    Architecture (fixed depth: 9 encoder convs, 7 decoder transposed convs):

    * encoder: (Conv2d + ReLU) x 9 with channels 1 -> C -> C -> C -> 2C -> 2C
      -> 2C -> 2C -> 2C -> 2C, then Flatten -> Linear(2C * s**2, latent_dim)
      -> ReLU, where ``s = convdim_outputs_e[-1]``.
    * ``mean`` / ``logvar``: Linear(latent_dim, 2C * s**2), so the sampled
      ``z`` has 2C * s**2 features (not ``latent_dim``).
    * decoder: Unflatten to (2C, s, s), then (ConvTranspose2d + ReLU) x 7 with
      channels 2C -> ... -> 2C -> C -> 1.

    Parameters
    ----------
    convdim_outputs_e : list of int
        Spatial sizes through the encoder; only the last entry is used.
    kernels_e, strides_e : list of int
        Kernel size / stride for each of the 9 (unpadded) encoder convs.
    convdim_outputs_d : list of int
        Spatial sizes through the decoder; stored but not otherwise used.
    kernels_d, strides_d : list of int
        Kernel size / stride for each of the 7 decoder transposed convs.
    latent_dim : int
        Width of the encoder's bottleneck Linear layer.
    """

    def __init__(self, convdim_outputs_e: list, kernels_e: list, strides_e: list,
                 convdim_outputs_d: list, kernels_d: list, strides_d: list, latent_dim: int):
        super().__init__()

        self.convdim = convdim_outputs_e
        self.kernels = kernels_e
        self.strides = strides_e
        self.convtranspose = convdim_outputs_d
        self.kernelsd = kernels_d
        self.stridesd = strides_d
        self.latent_dim = latent_dim

        self.C = 8
        C = self.C
        side = self.convdim[-1]
        flat = (C * 2) * side ** 2

        # Channel schedules. Layers are appended in the same order as the former
        # hand-written Sequential, so module indices / state_dict keys are unchanged.
        enc_in = [1, C, C, C, 2 * C, 2 * C, 2 * C, 2 * C, 2 * C]
        enc_out = [C, C, C, 2 * C, 2 * C, 2 * C, 2 * C, 2 * C, 2 * C]
        enc_layers = []
        for i, (c_in, c_out) in enumerate(zip(enc_in, enc_out)):
            enc_layers += [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, stride=self.strides[i], kernel_size=self.kernels[i]),
                nn.ReLU(),
            ]
        self.encoder = nn.Sequential(
            *enc_layers,
            nn.Flatten(),
            nn.Linear(flat, self.latent_dim),
            nn.ReLU(),
        )

        self.mean = nn.Linear(self.latent_dim, flat)
        self.logvar = nn.Linear(self.latent_dim, flat)

        dec_in = [2 * C] * 6 + [C]
        dec_out = [2 * C] * 5 + [C, 1]
        dec_layers = [nn.Unflatten(1, (C * 2, side, side))]
        for i, (c_in, c_out) in enumerate(zip(dec_in, dec_out)):
            dec_layers += [
                nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=self.kernelsd[i], stride=self.stridesd[i]),
                nn.ReLU(),
            ]
        self.decoder = nn.Sequential(*dec_layers)

    def reparametrize(self, mean, logvar):
        """Sample ``z = mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``.

        ``eps`` comes from ``torch.randn_like`` so it matches ``mean``'s shape,
        dtype and device. The previous version read a global ``device`` that is
        only defined in a later notebook cell.
        """
        eps = torch.randn_like(mean)
        return mean + torch.exp(logvar / 2) * eps

    def forward(self, x):
        """Encode ``x``, sample ``z``, decode it.

        Returns
        -------
        tuple
            ``(x_recon, z, mean, logv)``.
        """
        h = self.encoder(x)
        mean = self.mean(h)
        logv = self.logvar(h)
        z = self.reparametrize(mean, logv)
        x_recon = self.decoder(z)
        return x_recon, z, mean, logv
    

In [16]:
# Build the model from the layer configs computed above.
# NOTE(review): `Exocoder` is both imported from `models` (first cell) and
# redefined later in this notebook; which class is used here depends on
# execution order — re-run the class-definition cell before this one.
model = Exocoder(convdim_outputs_e=convdim_outputs,
                 kernels_e=kernels_forward,
                 strides_e=strides_forward,
                 convdim_outputs_d=convtrans_outputs,
                 kernels_d=kernels_backward,
                 strides_d=strides_backward,
                 latent_dim=500)  # bottleneck Linear width in the notebook's class definition

In [17]:
# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

In [18]:
def train(model, train_dataloader, optimizer, device, loss_fn, EPOCH=30):
    """Train ``model`` for ``EPOCH`` epochs, reporting the mean batch loss per epoch.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(batch)``; must return ``(x_recon, z, mean, logv)``.
    train_dataloader : iterable supporting ``len()``
        Yields unlabelled tensor batches; each is cast to float and moved to
        ``device``.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    device : torch.device
        Where each batch is moved before the forward pass.
    loss_fn : callable
        ``loss_fn(x_recon, batch, mean, logv)`` returning a scalar tensor.
    EPOCH : int, default 30
        Number of passes over ``train_dataloader``.

    Returns
    -------
    list of float
        Mean loss of each epoch, in order.

    Raises
    ------
    ValueError
        If ``train_dataloader`` has no batches (previously a ZeroDivisionError
        at the end of the first epoch).
    """
    if len(train_dataloader) == 0:
        raise ValueError("train_dataloader yields no batches")

    history = []
    with tqdm(total=len(train_dataloader) * EPOCH) as progress:
        model.train()
        for epoch in range(EPOCH):
            total_loss, batch_count = 0.0, 0
            for batch in train_dataloader:
                batch = batch.float().to(device)
                x_recon, z, mean, logv = model(batch)
                loss = loss_fn(x_recon, batch, mean, logv)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                batch_count += 1
                progress.update()

            epoch_loss = total_loss / batch_count
            history.append(epoch_loss)
            # tqdm.write prints above the bar instead of splitting it across
            # lines the way a bare print() does.
            progress.write(f'{epoch_loss}')
    return history

In [19]:
def vae_loss(x_recon, x, mean, logv):
    """VAE objective: summed squared reconstruction error plus a scaled KL term.

        recons = sum((x_recon - x) ** 2)
        kl     = -0.5 * sum(1 + logv - mean**2 - exp(logv)) / x.numel()

    The KL term is divided by the number of elements in ``x``. The previous
    version divided by the constant ``512 * 160 * 160`` (batch size times the
    crop size squared), which equals ``x.numel()`` only for a full
    ``(512, 1, 160, 160)`` batch and is wrong for a smaller final batch or any
    other input shape.

    NOTE(review): reconstruction is summed while KL is per-element, so KL
    carries very little weight relative to reconstruction — confirm this
    weighting is intended.

    Returns
    -------
    torch.Tensor
        Scalar loss.
    """
    recons = F.mse_loss(x_recon, x, reduction='sum')
    kl = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
    kl = kl / x.numel()
    return recons + kl

In [20]:
# Adam over all model parameters with a small fixed learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=2e-4)

In [21]:
# Train for 20 epochs; the mean batch loss of each epoch is printed below the progress bar.
train(model, train_dataloader, optimizer, device, vae_loss, EPOCH=20)

  5%|███▊                                                                        | 36/720 [00:18<05:00,  2.27it/s]

548280817891.55554


 10%|███████▌                                                                    | 72/720 [00:33<04:12,  2.56it/s]

521404131100.44446


 15%|███████████▎                                                               | 108/720 [00:50<04:30,  2.27it/s]

376550740423.1111


 20%|███████████████                                                            | 144/720 [01:04<04:30,  2.13it/s]

300950278144.0


 25%|██████████████████▊                                                        | 180/720 [01:22<04:09,  2.17it/s]

274222328945.77777


 30%|██████████████████████▌                                                    | 216/720 [01:37<03:50,  2.19it/s]

253728578218.66666


 35%|██████████████████████████▎                                                | 252/720 [01:53<02:29,  3.12it/s]

240802850588.44446


 40%|██████████████████████████████                                             | 288/720 [02:09<03:12,  2.24it/s]

236366879857.77777


 45%|█████████████████████████████████▊                                         | 324/720 [02:24<02:20,  2.81it/s]

233384838485.33334


 50%|█████████████████████████████████████▌                                     | 360/720 [02:41<02:47,  2.15it/s]

231246154865.77777


 55%|█████████████████████████████████████████▎                                 | 396/720 [02:56<02:17,  2.35it/s]

229068593379.55554


 60%|█████████████████████████████████████████████                              | 432/720 [03:13<02:11,  2.19it/s]

227205925091.55554


 65%|████████████████████████████████████████████████▊                          | 468/720 [03:28<01:53,  2.22it/s]

225486291399.1111


 70%|████████████████████████████████████████████████████▌                      | 504/720 [03:45<01:13,  2.95it/s]

223836851768.8889


 75%|████████████████████████████████████████████████████████▎                  | 540/720 [04:00<01:20,  2.24it/s]

222467168483.55554


 80%|████████████████████████████████████████████████████████████               | 576/720 [04:15<00:38,  3.74it/s]

221145360611.55554


 85%|███████████████████████████████████████████████████████████████▊           | 612/720 [04:31<00:46,  2.33it/s]

219934900679.1111


 90%|███████████████████████████████████████████████████████████████████▌       | 648/720 [04:45<00:20,  3.58it/s]

218720233699.55554


 95%|███████████████████████████████████████████████████████████████████████▎   | 684/720 [05:01<00:16,  2.23it/s]

217575862272.0


100%|███████████████████████████████████████████████████████████████████████████| 720/720 [05:16<00:00,  2.28it/s]

216318269212.44446





In [37]:
# Random decoder inputs sized from the model itself: forward() decodes the
# output of `model.mean`, so its out_features is the decoder's input width
# (the literal 16*32*32 only matched one configuration).
sample = torch.randn(100, model.mean.out_features, device=device, dtype=torch.float32)

In [39]:
# Shape check only: no_grad avoids building an autograd graph through the decoder.
with torch.no_grad():
    decoded = model.decoder(sample)
decoded.shape

torch.Size([100, 8, 160, 160])