In [0]:
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline  
from google.colab import drive
drive.mount('/content/drive')
import torch
from torch import nn, optim
from torch.nn import functional as F
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
import cv2
import glob

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
#change the URL if necessary; might take a while to get the images so pickle them afterwards (much faster to retrieve)

#images = [cv2.imread(file) for file in glob.glob("drive/My Drive/Parasitized/*.png")]
#images2 = [cv2.imread(file) for file in glob.glob("drive/My Drive/results/Uninfected/*.png")]

In [0]:
import pickle
#with open('drive/My Drive/images.pkl','wb') as f:
#  pickle.dump(images,f)
with open('drive/My Drive/images.pkl','rb') as f:
  images=pickle.load(f)

In [0]:
import pickle
#with open('drive/My Drive/images2.pkl','wb') as f:
#  pickle.dump(images2,f)
with open('drive/My Drive/images2.pkl','rb') as f:
  images2=pickle.load(f)

In [0]:
import torch.utils.data as data

class ImageFileset(data.Dataset):
  def __init__(self, flist, transform=None):
    self.imlist = flist
    self.transform = transform
    
  def __getitem__(self, index):
    img = self.imlist[index]
    if self.transform is not None:
      img = self.transform(img)
      
    return img
  
  def __len__(self):
    return len(self.imlist)

In [0]:
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(size=(128,128)),
    transforms.ToTensor(),
    transforms.Normalize((0.42983678, 0.39725652, 0.52416897), (0.28184757, 0.26598915, 0.34225056))
])

trainset=ImageFileset(images,transform=transform)
train_loader = torch.utils.data.DataLoader(trainset,batch_size=128, shuffle=True)

trainset_control=ImageFileset(images2,transform=transform)
train_loader_control = torch.utils.data.DataLoader(trainset_control,batch_size=128, shuffle=True)

In [0]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        #3*128*128=49152
        self.fc1 = nn.Linear(49152, 400)
        self.fc1_bn = nn.BatchNorm1d(400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc3_bn = nn.BatchNorm1d(400)
        self.fc4 = nn.Linear(400, 49152)

    def encode(self, x):
        h1 = F.relu(self.fc1_bn(self.fc1(x)))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3_bn(self.fc3(z)))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 49152))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [0]:
#Xavier initialization
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
        
import math  
#Kaiming initialization
def kaiming_init(m):
    if type(m) == nn.Linear:
      weights,bias = m.named_parameters()
      m.weight = torch.nn.Parameter(torch.randn(weights[1].shape[0],weights[1].shape[1])*math.sqrt(2./weights[1].shape[0]))
      m.bias.data.fill_(0)
        
model=VAE()
model.apply(kaiming_init)
model.cuda()
optimizer=optim.Adam(model.parameters(),lr=1e-4)

In [0]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar, epoch):
    #binary cross-entropy
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 49152))
    #Kullback-Leibler divergence
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + 0.1*KLD

def adjust_learning_rate(optimizer, epoch,lr):
    """Sets the learning rate to the initial LR decayed by 2 every 1 epochs"""
    lr = lr * (0.5 ** (epoch // 1))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [0]:
def train(epoch):
    model.train()
    train_loss = 0
    #change to train_loader_control for noninfected cells
    for batch_idx, val in enumerate(train_loader):
        adjust_learning_rate(optimizer,epoch-1,optimizer.param_groups[0]['lr'])
        val = val.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(val)
        loss = loss_function(recon_batch, val, mu, logvar, epoch)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 5 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(val), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(val)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

In [0]:
for epoch in range(1, 11):
        train(epoch)
        if epoch % 2 == 0:
          with torch.no_grad():
              sample = torch.randn(64, 20).cuda()
              sample = model.decode(sample)
              save_image(sample.view(64, 3, 128, 128),
                'drive/My Drive/results/sample_' + str(epoch) + '.png')

====> Epoch: 1 Average loss: 0.3637
====> Epoch: 2 Average loss: 0.2702
====> Epoch: 3 Average loss: 0.2182
====> Epoch: 4 Average loss: 0.2461
====> Epoch: 5 Average loss: 0.2845
====> Epoch: 6 Average loss: 0.2457
====> Epoch: 7 Average loss: 0.2243
====> Epoch: 8 Average loss: 0.2271
====> Epoch: 9 Average loss: 0.2179
====> Epoch: 10 Average loss: 0.2091
