In [1]:
import numpy as np
import torch
import torchvision
import pickle
import matplotlib.pyplot as plt
from dataloader import *

In [2]:
transform = transforms.Compose([Resize()])

In [3]:
data = np.load('./data/velocity64.pickle',allow_pickle=True)


In [4]:
#Normalize the data
data[:,:,:,0] = (data[:,:,:,0] - np.min(data[:,:,:,0]))/(np.max(data[:,:,:,0])-np.min(data[:,:,:,0]))
data[:,:,:,1] = (data[:,:,:,1] - np.min(data[:,:,:,1]))/(np.max(data[:,:,:,1]) - np.min(data[:,:,:,1]))

In [5]:
data = SmokeDataset(data=data,transform=transform)

In [6]:
for i,inputs in enumerate(data):
    largeImage=inputs[0]
    smallImage=inputs[1]
    print(largeImage.shape, smallImage.shape)
    break
    

(2, 64, 64) (2, 8, 8)


In [7]:
#visualise the data from 
import scipy.misc

def velocityFieldToPng(frameArray):
    """ Returns an array that can be saved as png with scipy.misc.toimage
    from a velocityField with shape [height, width, 2]."""
    outputframeArray = np.zeros((frameArray.shape[0], frameArray.shape[1], 3))
    for x in range(frameArray.shape[0]):
        for y in range(frameArray.shape[1]):
            # values above/below 1/-1 will be truncated by scipy
            frameArray[y][x] = (frameArray[y][x] * 0.5) + 0.5
            outputframeArray[y][x][0] = frameArray[y][x][0]
            outputframeArray[y][x][1] = frameArray[y][x][1]
    return outputframeArray

# Write the GAN Architecture


In [8]:
import torch.nn as nn
import torch.nn.functional as F

In [9]:
kernel = 2
stride = 1

In [10]:
class Discriminator(nn.Module):
    def __init__(self,conv_dim=2):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(conv_dim,conv_dim*2,kernel_size=kernel,stride=stride)
        self.conv2 = nn.Conv2d(conv_dim*2,conv_dim*4,kernel,stride)
        self.conv3 = nn.Conv2d(conv_dim*4,conv_dim*8,kernel,stride)
        self.conv4 = nn.Conv2d(conv_dim*8,conv_dim*16,kernel,stride)
        
    def forward(self,x):
        x = F.relu(self.conv1(x))
#         x = F.batch_norm(x)
        x = F.relu(self.conv2(x))
#         x = F.batch_norm(x)
        x = F.relu(self.conv3(x))
#         x = F.batch_norm(x)
        x = F.sigmoid(self.conv4(x))
        return x

In [11]:
Discriminator()

Discriminator(
  (conv1): Conv2d(2, 4, kernel_size=(2, 2), stride=(1, 1))
  (conv2): Conv2d(4, 8, kernel_size=(2, 2), stride=(1, 1))
  (conv3): Conv2d(8, 16, kernel_size=(2, 2), stride=(1, 1))
  (conv4): Conv2d(16, 32, kernel_size=(2, 2), stride=(1, 1))
)

In [12]:
class Generator(nn.Module):
    def __init__(self,deconv_dim=2):
        super(Generator, self).__init__()
        self.deconv1 = nn.ConvTranspose2d(deconv_dim,deconv_dim*2,3,stride=2)
        self.deconv2 = nn.ConvTranspose2d(deconv_dim*2,deconv_dim*4,3,stride=2)
        self.deconv3 = nn.ConvTranspose2d(deconv_dim*4,deconv_dim*8,3,stride =2)
        self.deconv4 = nn.ConvTranspose2d(deconv_dim*8,deconv_dim,3,stride=2)
    
    def forward(self,x):
        
        x = F.relu(self.deconv1(x))
#         x = F.batch_norm(x)
        x = F.relu(self.deconv2(x))
#         x = F.batch_norm(x)
        x = F.relu(self.deconv3(x))
#         x = F.batch_norm(x)
        x = F.sigmoid(self.deconv4(x))
        
        return x

In [13]:
Generator()

Generator(
  (deconv1): ConvTranspose2d(2, 4, kernel_size=(3, 3), stride=(2, 2))
  (deconv2): ConvTranspose2d(4, 8, kernel_size=(3, 3), stride=(2, 2))
  (deconv3): ConvTranspose2d(8, 16, kernel_size=(3, 3), stride=(2, 2))
  (deconv4): ConvTranspose2d(16, 2, kernel_size=(3, 3), stride=(2, 2))
)

In [14]:
params = {'batch_size': 10,
          'shuffle': True,
          'num_workers': 6}
max_epochs = 1

In [15]:
dataset = torch.utils.data.DataLoader(dataset=data,**params)

In [16]:
dataiter = iter(dataset)
Dinput,Ginput = dataiter.next()
print(Dinput.shape,Ginput.shape)


torch.Size([10, 2, 64, 64]) torch.Size([10, 2, 8, 8])


In [17]:
D = Discriminator()
G = Generator()

In [18]:
D(Dinput)
G(Ginput)

tensor([[[[0.5332, 0.5400, 0.5368,  ..., 0.5384, 0.5393, 0.5414],
          [0.5328, 0.5344, 0.5271,  ..., 0.5306, 0.5338, 0.5316],
          [0.5373, 0.5427, 0.5460,  ..., 0.5430, 0.5411, 0.5433],
          ...,
          [0.5383, 0.5410, 0.5459,  ..., 0.5433, 0.5414, 0.5440],
          [0.5346, 0.5336, 0.5292,  ..., 0.5274, 0.5348, 0.5317],
          [0.5418, 0.5374, 0.5456,  ..., 0.5445, 0.5382, 0.5401]],

         [[0.4805, 0.4876, 0.4725,  ..., 0.4708, 0.4869, 0.4659],
          [0.4769, 0.4786, 0.4744,  ..., 0.4735, 0.4789, 0.4725],
          [0.4788, 0.4913, 0.4555,  ..., 0.4587, 0.4908, 0.4572],
          ...,
          [0.4753, 0.4907, 0.4576,  ..., 0.4586, 0.4905, 0.4571],
          [0.4758, 0.4776, 0.4759,  ..., 0.4737, 0.4789, 0.4735],
          [0.4716, 0.4774, 0.4633,  ..., 0.4647, 0.4785, 0.4655]]],


        [[[0.5332, 0.5400, 0.5367,  ..., 0.5384, 0.5393, 0.5414],
          [0.5328, 0.5344, 0.5271,  ..., 0.5306, 0.5338, 0.5316],
          [0.5373, 0.5427, 0.5460,  ...,