In [26]:
# Project root on the cluster filesystem; every other path is derived from it.
ROOT_DIR = '/gpfs/commons/groups/gursoy_lab/aelhussein/ot_cost/otcost_fl_rebase'

# Location of the ISIC data, passed to the dataset constructor as `data_path`.
DATA_DIR = f'{ROOT_DIR}/data/ISIC'

# --- Standard library ---
import copy
import json
import sys

# --- Third-party ---
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader as dl

# --- Project-local ---
# The project code is not an installed package, so its directory has to be on
# the module search path before `dataset` can be imported.
sys.path.append(f'{ROOT_DIR}/code/ISIC/')
import dataset

# Also expose the efficientnet_ae directory for later imports.
sys.path.append(f'{ROOT_DIR}/code/ISIC/efficientnet_ae')

In [2]:
# Number of samples per batch handed to the DataLoaders.
BATCH_SIZE = 1
# Learning rate and epoch count -- presumably consumed by a training loop
# further down the notebook (not used in the cells immediately below).
LR = 0.05
EPOCHS = 1_000

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Training split. `pooled=True` presumably merges all federated clients into a
# single dataset -- confirm against dataset.FedIsic2019.
train_data = dataset.FedIsic2019(train=True, pooled=True, data_path=DATA_DIR)
# Shuffle the training data so each epoch sees the samples in a new order.
train_loader = dl(train_data, batch_size=BATCH_SIZE, shuffle=True)

# Held-out split, built with the same options but train=False.
val_data = dataset.FedIsic2019(train=False, pooled=True, data_path=DATA_DIR)
# Validation is iterated in a fixed order: shuffling gives no benefit at
# evaluation time and only makes per-sample results harder to reproduce.
val_loader = dl(val_data, batch_size=BATCH_SIZE, shuffle=False)

In [10]:
class Autoencoder(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    Pipeline: conv encoder (two 2x2 max-pools, so the spatial size shrinks by
    a factor of 4) -> flatten -> MLP bottleneck producing an ``n_emb``-wide
    embedding -> linear expansion back to the encoder feature-map size ->
    decoder (two stride-2 transposed convolutions restore the input
    resolution) -> sigmoid.

    Args:
        n_emb: Width of the embedding produced by the bottleneck.
        in_channels: Number of channels of the input (and reconstructed)
            image. Defaults to 3.
        input_size: Height/width of the square input image. Must be a positive
            multiple of 4 so that the decoder output has exactly the input
            resolution. Defaults to 200, which gives the 50x50 feature map.

    Raises:
        ValueError: If ``input_size`` is not a positive multiple of 4.

    Notes:
        * With the default geometry the first bottleneck layer is
          ``Linear(80000, 20000)``, i.e. 1.6e9 weights (about 6.4 GB in
          float32). Pass a smaller ``input_size`` for quick experiments.
        * The bottleneck contains ``Dropout`` and the encoder contains
          ``BatchNorm2d``; call ``model.eval()`` before extracting embeddings
          if deterministic outputs are required.
        * Sub-module names and ``nn.Sequential`` layer order are fixed, so
          ``state_dict`` keys do not depend on the constructor arguments.
    """

    # Channel count of the feature map at the encoder output / decoder input.
    _LATENT_CHANNELS = 32

    def __init__(self, n_emb, in_channels=3, input_size=200):
        super().__init__()
        if input_size < 4 or input_size % 4 != 0:
            raise ValueError(
                f"input_size must be a positive multiple of 4, got {input_size}"
            )

        self.n_emb = n_emb
        self.in_channels = in_channels
        self.input_size = input_size
        # Two 2x2 max-pools divide the spatial size by 4.
        self.latent_size = input_size // 4

        latent_ch = self._LATENT_CHANNELS
        flat_dim = latent_ch * self.latent_size ** 2
        # Hidden width keeps the same channel count at half the spatial side
        # (32*25*25 for the default 200x200 input); never less than one cell.
        hidden_dim = latent_ch * max(1, self.latent_size // 2) ** 2

        self.encoder = nn.Sequential(
            *self._conv_block(in_channels, 16),
            *self._conv_block(16, 32),
            nn.MaxPool2d(2, 2),
            *self._conv_block(32, 64),
            *self._conv_block(64, 128),
            nn.MaxPool2d(2, 2),
            *self._conv_block(128, 64),
            *self._conv_block(64, 32),
            *self._conv_block(32, latent_ch),
        )

        self.bottleneck = nn.Sequential(
            nn.Linear(flat_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, self.n_emb),
        )

        # Maps an embedding back to a flattened encoder-sized feature map.
        self.expand = nn.Sequential(nn.Linear(self.n_emb, flat_dim))

        self.decoder = nn.Sequential(
            nn.Conv2d(latent_ch, 32, 3, padding=1),
            nn.ReLU(),
            # kernel 4 / stride 2 / padding 1 doubles the spatial size exactly.
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 16, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, in_channels, 3, padding=1),
            # Squash the reconstruction into [0, 1].
            nn.Sigmoid(),
        )

    @staticmethod
    def _conv_block(in_ch, out_ch):
        """Return the (3x3 conv, LeakyReLU, BatchNorm) triple used by the encoder."""
        return (
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(out_ch),
        )

    def forward(self, x, get_embedding=False):
        """Encode ``x`` and return either its embedding or its reconstruction.

        Args:
            x: Batch of images shaped
                ``(batch, in_channels, input_size, input_size)``.
            get_embedding: When True, stop after the bottleneck and return the
                ``(batch, n_emb)`` embedding instead of the reconstruction.

        Returns:
            The embedding if ``get_embedding`` is True, otherwise the
            reconstruction with the same shape as ``x`` and values in [0, 1].
        """
        features = self.encoder(x)
        # flatten() copies when needed, so non-contiguous activations
        # (e.g. channels_last) do not raise as .view() would.
        embedding = self.bottleneck(features.flatten(1))
        if get_embedding:
            return embedding

        expanded = self.expand(embedding)
        feature_map = expanded.view(
            expanded.size(0),
            self._LATENT_CHANNELS,
            self.latent_size,
            self.latent_size,
        )
        return self.decoder(feature_map)


In [5]:
# Pull a single batch from the training loader to smoke-test the model.
X, y = next(iter(train_loader))
# Swap dims 1 and 2 -- presumably to put channels on dim 1 as Conv2d expects;
# TODO confirm against the layout returned by dataset.FedIsic2019.
X = torch.transpose(X, 1, 2)

In [27]:
# Mean squared error averaged over every element (the default reduction).
criterion = nn.MSELoss(reduction="mean")

In [12]:
# Width of the embedding produced by the autoencoder's bottleneck.
n_emb = 1000
model = Autoencoder(n_emb=n_emb)

In [13]:
# NOTE: despite its name, `embedding_point` holds the decoder output (the
# reconstruction of X): forward() only returns the embedding when called with
# get_embedding=True.
embedding_point = model(X, get_embedding=False)

In [28]:
# Reconstruction error of the untrained model on the sample batch; left as the
# cell's last expression so the notebook displays it.
criterion(input=embedding_point, target=X)

tensor(0.0347, grad_fn=<MseLossBackward0>)