In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

import sys
import import_ipynb
from pathlib import Path

dir = Path('notebooks')
sys.path.insert(0, str(dir.resolve()))
import wgan
import globals
import ebm
import ddm

import matplotlib.pyplot as plt
import numpy as np

import os
from PIL import Image


In [None]:
class LoadDataset(Dataset):
    """Unlabelled image dataset over the files found directly inside a directory.

    Only the top level of ``root_dir`` is scanned. Files are selected by
    extension (.png, .jpg, .jpeg, compared case-insensitively) and every item
    is returned as an RGB image.

    Args:
        root_dir: Directory that contains the image files.
        transform: Optional callable applied to each loaded PIL image before
            it is returned. When omitted, the PIL image is returned as-is.
    """

    # File extensions treated as images (compared in lower case).
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform = None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns names in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.image_paths = [
            os.path.join(root_dir, fname)
            for fname in sorted(os.listdir(root_dir))
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
        ]

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx``, convert it to RGB and apply the transform, if any."""
        img_path = self.image_paths[idx]
        # The context manager closes the underlying file; convert() hands
        # back a fully loaded image that stays valid afterwards.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

In [None]:
# Preprocessing shared by all three splits: square resize, then PIL image -> tensor.
transform = transforms.Compose([
    transforms.Resize([globals.IMAGE_SIZE, globals.IMAGE_SIZE]), # Resizing images to IMAGE_SIZE x IMAGE_SIZE
    transforms.ToTensor(), # Converting images to float tensors scaled to [0, 1]
])
traindataset = LoadDataset(root_dir = '../data/SARscope/train/', transform = transform)
testdataset = LoadDataset(root_dir = '../data/SARscope/test/', transform = transform)
validdataset = LoadDataset(root_dir = '../data/SARscope/valid/', transform = transform)


# Only the training split is shuffled. The evaluation splits keep a fixed
# order so that metrics and sample visualisations are reproducible.
trainloader = DataLoader(dataset = traindataset, batch_size = globals.BATCH_SIZE, shuffle = True)
testloader = DataLoader(dataset = testdataset, batch_size = globals.BATCH_SIZE, shuffle = False)
validloader = DataLoader(dataset = validdataset, batch_size = globals.BATCH_SIZE, shuffle = False)

In [None]:
def imshow(img):
    """Display a channel-first image tensor with matplotlib.

    Args:
        img: Image tensor laid out as (channels, height, width), such as the
            output of ``torchvision.utils.make_grid``.
    """
    # detach()/cpu() let this accept tensors that require grad or live on an
    # accelerator; for a plain CPU tensor they change nothing.
    npimg = img.detach().cpu().numpy()
    # matplotlib wants channel-last data, so move the channel axis to the end.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()



# Pull a single batch from the training loader as a visual sanity check.
dataiter = iter(trainloader)
images = next(dataiter)

# Tile the batch into one grid image and display it, then report the batch shape.
imshow(torchvision.utils.make_grid(images))
print(images.shape)

In [None]:
# Summarise the batch instead of printing raw pixel values: shape, dtype and
# value range are what is needed to check the preprocessing.
print('shape={} dtype={} min={:.4f} max={:.4f}'.format(
    tuple(images.shape), images.dtype, images.min().item(), images.max().item()))

In [None]:
def show_generated_images(generator, epoch, num_images = 1):
    """Sample images from ``generator`` and display them in a single figure.

    Args:
        generator: Callable that maps a (num_images, globals.Z_DIM) noise
            tensor to a batch of images. Assumed to return channel-first
            (N, C, H, W) values in [-1, 1] (e.g. a tanh output) — TODO confirm
            against the generator definition.
        epoch: Epoch the samples belong to; shown in the figure title.
        num_images: Number of images to sample. All of them are tiled into
            one grid, so any positive count works.
    """
    # Create the noise on the generator's device when it exposes parameters;
    # otherwise stay on the CPU.
    parameters = getattr(generator, 'parameters', None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        z = torch.randn(size = (num_images, globals.Z_DIM), device = device)
        generated_images = generator(z)

    # Rescale from [-1, 1] to [0, 1] for display and clamp any overshoot.
    generated_images = ((generated_images + 1) / 2).clamp(0, 1).cpu()

    # make_grid tiles the whole batch into one (C, H, W) image. A bare
    # squeeze(0) only yields a displayable image when num_images == 1.
    grid = torchvision.utils.make_grid(generated_images)
    plt.imshow(grid.numpy().transpose(1, 2, 0))
    plt.title('Epoch {}'.format(epoch))
    plt.show()
   

In [None]:
# Model switches: EBM takes precedence over DDM; with both off the WGAN is built.
EBM = True
DDM = False

if not (EBM or DDM):
    # WGAN: generator and critic each get their own Adam optimizer.
    generator = wgan.Generator()
    critic = wgan.Critic()
    critic_optimizer = torch.optim.Adam(critic.parameters(), lr = 0.0001)
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr = 0.0001)
    model = wgan.wgan(generator, critic, critic_optimizer, generator_optimizer)
elif EBM:
    model = ebm.EBM()
else:
    model = ddm.DiffusionModel()

In [None]:

def train_one_epoch():
    """Run one pass over ``trainloader`` and return the mean batch loss.

    Every batch is handed to the notebook-level ``model``; whatever training
    step takes place happens inside ``model(data)``.

    Assumes ``model(data)`` returns a scalar loss (a Python number or a
    one-element tensor) — TODO confirm for every model option.

    Returns:
        float: Mean of the per-batch losses over the epoch.

    Raises:
        ValueError: If ``trainloader`` yields no batches.
    """
    total_loss = 0.0
    num_batches = 0

    for i, data in enumerate(trainloader):
        print('Batch number {}'.format(i))

        loss = model(data)
        # float() turns a tensor loss into a plain number, so no autograd
        # graph is kept alive while the epoch total is accumulated.
        total_loss += float(loss)
        num_batches += 1

    if num_batches == 0:
        raise ValueError('trainloader yielded no batches')

    return total_loss / num_batches


In [None]:
# Manual epoch counter; the banner below prints it 1-based.
epoch_number = 0

for epoch in range(globals.EPOCHS):
    print('EPOCH {}'.format(epoch_number + 1))

    # Put the model in training mode for the pass over the training data.
    model.train(True)
    avg_loss = train_one_epoch()
    # Report the value returned for this epoch so progress is visible.
    print('LOSS train {}'.format(avg_loss))

    # Switch back to evaluation mode once the epoch is done.
    model.eval()
    epoch_number += 1