In [9]:
import os
from os import path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import datetime
import socket
import random
import easydict
from pytz import timezone
import operator
import pandas as pd

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from IPython.display import Image

%matplotlib inline


In [54]:
tz = timezone('US/Eastern')

path = "/Users/alimehdi/Downloads/data/" # data directory

# train data
train_wo_file = "train_without_missing.txt"
train_w_file = "train_with_missing.txt"
train_meta = "train.meta"

# validation data
validation_wo_file = "validation_without_missing.txt"
validation_w_file = "validation_with_missing.txt"
validation_meta = "validation.meta"

# test data
test_wo_file = "test_without_missing.txt"
test_w_file = "test_with_missing.txt"
test_meta = "test.meta"

# checkpoint_epoch = 10
checkpoint = None

args = easydict.EasyDict({
    "log_name": "imputation_model",
    "model_type": "resnet2d",
    "nb_epochs":100,
    "bottleneck_dim":1024,
    "encoder_dim":6,
    "dencoder_dim":6,
    "encoder_dim":0,
    "dencoder_dim":0,
    "batch_size":512,
    "seq_len":199,
    "learning_rate":0.0001,
    "betal":0.5,
    "seed":1,
    "missing_char":".",
    "resnet_mix":0.5,
    "optimizer":"adam",
    "haplotype_cutoff":1.0,
    "filter_size":7,
    "with_region_metadata":True,
    "continent": False,
    "save_model": True,
    "save_embedded_haplotypes": True,
    "train": True
})

os.environ["CUDA_VISIBLE_DEVICES"]="0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

seed = args.seed
torch.backends.cudnn.deterministic = True
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

charmap = {"A":0, "C":1, "G":2, "T":3, args.missing_char:4}
charmap_rev = {0:"A", 1:"C", 2:"G", 3:"T"}





Collect metadata info

In [55]:
df_train = pd.read_csv(path + train_meta, sep = '\t',header=None)
df_validation = pd.read_csv(path +validation_meta, sep='\t', header=None)
df_test = pd.read_csv(path+ test_meta, sep='\t', header=None)

In [56]:
df_train=df_train.fillna("")
df_validation=df_validation.fillna("")
df_test=df_test.fillna("")

In [57]:
df_validation.head()

Unnamed: 0,0,1,2
0,SRR11593361,Germany,North Rhine Westphalia
1,ERR4241279,United Kingdom,Scotland
2,ERR4243656,United Kingdom,England
3,ERR4307646,United Kingdom,Wales
4,SRR12188727,USA,Washington


In [58]:
countries = {region:i for i, region in enumerate(sorted(set(df_train[1])|set(df_validation[1])|set(df_test[1])))}
countries['NaN'] = len(countries) 
print(countries)

{'': 0, 'Australia': 1, 'China': 2, 'Egypt': 3, 'Germany': 4, 'India': 5, 'Israel': 6, 'Japan': 7, 'Morocco': 8, 'South Africa': 9, 'South Korea': 10, 'Spain': 11, 'Timor-Leste': 12, 'Turkey': 13, 'USA': 14, 'United Kingdom': 15, 'NaN': 16}


In [59]:
if args.continent:
    meta_region = continents
else:
    meta_region = countries

In [60]:
def charmap_to_onehots(chars, charmap, model_type="resnet2d"):
    onehots = np.int_(np.zeros([len(chars),len(charmap)]))
    I = np.eye(len(charmap))
    for i, char in enumerate(chars):
        onehots[i] = I[charmap[char]]
        
    if model_type == "mlp":
        onehots = np.reshape(onehots, [-1])
    if model_type == "resnet1d" or model_type == "resnet2d":
        onehots = onehots.T
    return onehots

def metadata_to_one_hot(data, nb_classes, model_type = "resnet2d"):
    targets = np.array(data.reshape(-1))
    return np.eye(nb_classes)[targets].astype(int)[0]

def listToDict(lst):
    op = {lst[i].strip(): lst[i+1].strip() for i in range(0,len(lst), 2)}
    return op
        
    
# for GANs

def get_generator_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True)
    )

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples, z_dim, device=device)

def get_discriminator_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(0.2, inplace=True)
    )

def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    fake_noise = get_noise(num_images, z_dim, device=device)
    fake = gen(fake_noise)
    disc_fake_pred = disc(fake.detach())
    disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
    disc_real_pred = disc(real)
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
    disc_loss = (disc_fake_loss + disc_real_loss) / 2
    return disc_loss

def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    fake_noise = get_noise(num_images, z_dim, device=device)
    fake = gen(fake_noise)
    disc_fake_pred = disc(fake)
    gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
    return gen_loss

In [61]:
class Dataset(Dataset):
    def __init__(self, path, files, model, split="train", charmap=charmap, haplotype_cutoff = args.haplotype_cutoff,
                meta_continent = args.continent, meta_region = meta_region):
        
        self.dataset = {}
        for file in files[0]:
            with open(os.path.join(path,file),'r') as f:
                content = f.readlines()
                self.dataset = {**self.dataset, **listToDict(content)}
                
        self.dataset_new = {}
        for k, v in self.dataset.items():
            if '.' in v:
                if v.count('.')/len(v) <= args.haplotype_cutoff:
                    self.datset_new[k] = v
                
                else:
                    self.dataset_new[k] = v
            
            self.dataset = self.dataset_new
            
            self.sequences = [charmap_to_onehots(sequence, charmap, model) for sequence in list(self.datset.values())]
            self.ids = list(range(len(list(self.dataset))))
            
            index = 0
            if args.continent: index = 2
            else: index = 3
                
            meta_place = {}
            for file in files[1]:
                df = pd.read_csv(path + file, sep = '\t', lineterminator = '\n', header = None)
                meta_temp = {name:place for name, place in zip(list(df[0]), list(df[index]))}
                meta_place = {**meta_place, **meta_temp}
            
            self.meta_place = {}
            for i, name in enumerate(self.dataset):
                self.meta_place[i] = metadata_to_one_hot(meta_region[meta_place[name]], len(meta_region)+1)
                
        def __getitem__(self, i):
            id = self.ids[i]
            return self.sequences[id], self.meta_place[id]
        
        def __len__(self):
            return len(self.ids)

In [62]:
dataset_train = Dataset(path, [[train_wo_file], [train_meta]], args.model_type)
dataset_validation = Dataset(path, [[train_wo_file, validation_wo_file], [train_meta, validation_meta]], args.model_type)


IndexError: list index out of range

In [63]:
criterion = nn.BCEWithLogitsLoss()
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.00001
device = 'cpu'

dataloader_train = DataLoader(dataset_train, args.batch_size, drop_last = True, shuffle = True)
if args.train:
    dataloader_validation = DataLoader(dataset_validation, 1,drop_last=False,shuffle= False)

NameError: name 'dataset_train' is not defined

In [None]:
class Generator(nn.Module):
    def __init__(self, z_dim=29903, im_dim=29903, hidden_dim=128):
        super(Generator, self).__init__()
        # Build the neural network
        self.gen = nn.Sequential(
            get_generator_block(z_dim, hidden_dim),
            get_generator_block(hidden_dim, hidden_dim * 2),
            get_generator_block(hidden_dim * 2, hidden_dim * 4),
            get_generator_block(hidden_dim * 4, hidden_dim * 8),
            nn.Linear(hidden_dim * 8, im_dim),
            nn.Sigmoid()
        )
    def forward(self, noise):
        return self.gen(noise)
    
class Discriminator(nn.Module):
    def __init__(self, im_dim=29903, hidden_dim=128):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            get_discriminator_block(im_dim, hidden_dim * 4),
            get_discriminator_block(hidden_dim * 4, hidden_dim * 2),
            get_discriminator_block(hidden_dim * 2, hidden_dim),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, image):
        return self.disc(image)
    

In [None]:
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=args.learning_rate)
disc = Discriminator().to(device) 
disc_opt = torch.optim.Adam(disc.parameters(), lr=args.learning_rate)

In [None]:
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
test_generator = True # Whether the generator should be tested
gen_loss = False
for epoch in range(n_epochs):
  
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader1):
        cur_batch_size = len(real)

        # Flatten the batch of real images from the dataset
        real = real.view(cur_batch_size, -1).to(device)

        # Zero out the gradients before backpropagation
        disc_opt.zero_grad()

        # Calculate discriminator loss
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)

        # Update gradients
        disc_loss.backward(retain_graph=True)

        # Update optimizer
        disc_opt.step()

        # For testing purposes, to keep track of the generator weights
        if test_generator:
            old_generator_weights = gen.gen[0][0].weight.detach().clone()

        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()
        # For testing purposes, to check that your code changes the generator weights
        if test_generator:
            assert torch.any(gen.gen[0][0].weight.detach().clone() != old_generator_weights)

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step

        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1