In [67]:
import torch
import torch.nn as nn 
import torch.nn.functional as f
import numpy as np

class UNET(nn.Module):
    """
    Description: 
    Inputs:
        - Input data (spectrograms)
    Outputs:
        - Reconstructed data
        - Latent space data
    """
    def __init__(self):
        super(UNET, self).__init__()
        
        #down
        self.c1 = nn.Conv2d(1, 8, kernel_size=(3,3), stride=(2,2), padding=(1,1)) 
        self.c2 = nn.Conv2d(8, 16, kernel_size=(3,3), stride=(2,2), padding=(1,1))
        self.c3 = nn.Conv2d(16, 32, kernel_size=(3,3), stride=(2,2), padding=(1,1))
        self.c4 = nn.Conv2d(32, 64, kernel_size=(3,3), stride=(2,2), padding=(0,1))
        self.c5 = nn.Conv2d(64, 128, kernel_size=(3,3), stride=(2,2), padding=(1,0))
        
        #latent
        
        self.l1 = nn.Linear(1152,9)
        self.f1 = nn.Flatten()
        
        #up
        self.l2=nn.Linear(9,1152)
        #up
        self.c6=nn.ConvTranspose2d(128, 64, kernel_size=(3,3), stride=(2,2), padding=(1,0))  
        self.c7=nn.ConvTranspose2d(64, 32, kernel_size=(3,3), stride=(2,2), padding=(0,1))      
        self.c8=nn.ConvTranspose2d(32, 16, kernel_size=(3,3), stride=(2,2), padding=(1,1))
        self.c9=nn.ConvTranspose2d(16, 8, kernel_size=(3,3), stride=(2,2), padding=(0,1))
        self.c10=nn.ConvTranspose2d(8, 1, kernel_size=(3,3), stride=(2,2), padding=(0,1))

    def forward(self, x):
        
        #down 
        x1 = f.relu(self.c1(x))
        x2 = f.relu(self.c2(x1))
        x3 = f.relu(self.c3(x2))
        x4 = f.relu(self.c4(x3))
        x5 = f.relu(self.c5(x4))
        
        
        x_latent=f.relu(self.l1(self.f1(x5)))
        x_unlatent=f.relu(self.l2(x_latent))
        #up
        x_unlatent =  x_unlatent.view(-1, 128, 3, 3)
        
        #up
        x6= f.relu(self.c6(x_unlatent))
        

        print(x4.shape)
        print(x6.shape)
        #print(tf.size(x4))
        
        u6=torch.cat([x6, x4])
        x7= f.relu(self.c7(u6))
        
        print(x3.shape)
        print(x7.shape)
        u7=torch.cat([x7, x3])
        x8= f.relu(self.c8(u7))
        
        print(x2.shape)
        print(x8.shape)
        u8=torch.cat([x8, x2[:,:,:21,:]])
        x9= f.relu(self.c9(u8))
        
        print(x1.shape)
        print(x9.shape)
        u9=torch.cat([x9, x1[:,:,:43,:49]])
        #out=self.c10(u9)[:,:,4:-4,1:]
        
        return out

In [49]:
%reload_ext autoreload
%autoreload 2
from data import *
from train import *
from networks import *
from visualisation import *

VALIDATION_DATASET_PERC = 0.2
TEST_DATASET_PERC = 0.1

dataset = H5SeismicDataset(
    filepath="/datasets/ee228-sp21-A00-public/RIS_Seismic-001.h5",
    transform = transforms.Compose(
        [SpecgramShaper(), SpecgramToTensor()]
    )
)
#dataset[0]
dataloaders = getDataloaderSplit(dataset, VALIDATION_DATASET_PERC,TEST_DATASET_PERC)

In [68]:
if not torch.cuda.is_available():
    print('No cuda')
#params= {'lr':0.01,'epochs':50}
device = torch.device("cuda")
model= UNET()
PATH_WEIGHTS='./saved_models/UNET_weights'
torch.save(model.state_dict(),PATH_WEIGHTS)
model.to(device)
outputs = pretrain(model= model,dataloaders=dataloaders,device=device)
PATH='./saved_models/UNET'
torch.save(model, PATH)

torch.Size([512, 64, 5, 7])
torch.Size([512, 64, 5, 7])
torch.Size([512, 32, 11, 13])
torch.Size([1024, 32, 11, 13])
torch.Size([512, 16, 22, 25])
torch.Size([1536, 16, 21, 25])
torch.Size([512, 8, 44, 50])
torch.Size([2048, 8, 43, 49])


ValueError: too many values to unpack (expected 2)