In [1]:
import torch
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
import sys
import argparse
import matplotlib.pyplot as plt
plt.rcParams["axes.grid"] = False
import matplotlib.image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import gzip
import struct
import array
from urllib.request import urlretrieve
from torch.distributions.multivariate_normal import MultivariateNormal
import torch
from torch.utils.data import Dataset

In [2]:
train_transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


Loading the data

In [None]:
dataset = datasets.ImageFolder('/kaggle/input/brian-tumor-dataset/Brain Tumor Data Set/Brain Tumor Data Set/', transform=train_transform)


In [None]:
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Apply validation transform to the validation dataset
val_dataset.dataset.transform = val_transform


In [None]:
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True, num_workers=4)


In [29]:
import torch
import torch.nn as nn
import numpy as np
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.optim import Adam
import torch.nn.functional as F

def array_to_image(array):
    return np.reshape(np.array(array), [28, 28])

def concat_images(images, row, col, padding=3):
    result = np.zeros((28 * row + (row - 1) * padding, 28 * col + (col - 1) * padding))
    for i in range(row):
        for j in range(col):
            result[i * 28 + (i * padding): i * 28 + (i * padding) + 28,
                   j * 28 + (j * padding): j * 28 + (j * padding) + 28] = images[i + j * row]
    return result

class Encoder(nn.Module):
    def __init__(self, latent_dimension, hidden_units, data_dimension):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(data_dimension, hidden_units)
        self.bn1 = nn.BatchNorm1d(hidden_units)
        self.fc2_mu = nn.Linear(hidden_units, latent_dimension)
        self.fc2_sigma = nn.Linear(hidden_units, latent_dimension)

    def forward(self, x):
        hidden = self.bn1(F.tanh(self.fc1(x)))
        mu = self.fc2_mu(hidden)
        log_sigma_square = self.fc2_sigma(hidden)
        sigma_square = torch.exp(torch.clamp(log_sigma_square, max=10))  # Clamping to avoid overflow
        return mu, sigma_square

class Decoder(nn.Module):
    def __init__(self, latent_dimension, data_dimension, hidden_units=500):
        super(Decoder, self).__init__()
        self.fc1_dec = nn.Linear(latent_dimension, hidden_units)
        self.fc2_dec = nn.Linear(hidden_units, data_dimension)

    def forward(self, z):
        hidden_dec = F.tanh(self.fc1_dec(z))
        p = torch.sigmoid(self.fc2_dec(hidden_dec))
        return p

class VAE(nn.Module):
    def __init__(self, args):
        super(VAE, self).__init__()
        self.encoder = Encoder(args.latent_dimension, args.hidden_units, args.data_dimension)
        self.decoder = Decoder(args.latent_dimension, args.data_dimension, args.hidden_units)
        self.e_path = args.e_path
        self.d_path = args.d_path
        if args.resume_training:
            self.load_state_dict(torch.load(self.e_path))

    @staticmethod
    def sample_diagonal_gaussian(mu, sigma_square):
        sigma = torch.sqrt(sigma_square)
        return mu + torch.randn_like(mu) * sigma

    @staticmethod
    def sample_Bernoulli(p):
        return torch.bernoulli(p)

    @staticmethod
    def logpdf_diagonal_gaussian(z, mu, sigma_square):
        sigma_square = torch.clamp(sigma_square, min=1e-6)
        covariance_matrix = torch.diag_embed(sigma_square)
        dist = MultivariateNormal(mu, covariance_matrix)
        return dist.log_prob(z)

    @staticmethod
    def logpdf_bernoulli(x, p):
        return (x * torch.log(p + 1e-8) + (1 - x) * torch.log(1 - p + 1e-8)).sum(dim=1)

    def elbo_loss(self, sampled_z, mu, sigma_square, x, p):
        log_q = self.logpdf_diagonal_gaussian(sampled_z, mu, sigma_square)
        z_mu = torch.zeros_like(mu)
        z_sigma = torch.ones_like(sigma_square)
        log_p_z = self.logpdf_diagonal_gaussian(sampled_z, z_mu, z_sigma)
        log_p = self.logpdf_bernoulli(x, p)
        return (log_p + log_p_z - log_q).mean()

    def train_model(self, train_loader, val_loader, epochs=200, save_interval=10):
        optimizer = Adam(self.parameters(), lr=0.001)
        for epoch in range(epochs):
            self.train()
            train_loss = 0
            for data, _ in train_loader:
                data = data.view(data.size(0), -1)
                optimizer.zero_grad()
                mu, sigma_square = self.encoder(data)
                zs = self.sample_diagonal_gaussian(mu, sigma_square)
                p = self.decoder(zs)
                loss = -self.elbo_loss(zs, mu, sigma_square, data, p)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                optimizer.step()
                train_loss += loss.item()
            train_loss /= len(train_loader.dataset)
            print(f'Epoch {epoch+1}, Average Training Loss: {train_loss:.4f}')
            if epoch % save_interval == 0:
                self.evaluate(val_loader)
                torch.save(self.state_dict(), f'model_epoch_{epoch}.pth')

    def evaluate(self, loader):
        self.eval()
        val_loss = 0
        with torch.no_grad():
            for data, _ in loader:
                data = data.view(data.size(0), -1)
                mu, sigma_square = self.encoder(data)
                zs = self.sample_diagonal_gaussian(mu, sigma_square)
                p = self.decoder(zs)
                loss = -self.elbo_loss(zs, mu, sigma_square, data, p)
                val_loss += loss.item()
        val_loss /= len(loader.dataset)
        print(f'Validation Loss: {val_loss:.4f}')


In [14]:
train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert image to grayscale
    transforms.Resize((28, 28)),                  # Resize images to 28x28
    transforms.ToTensor(),                        # Convert images to PyTorch tensors
    transforms.Normalize((0.5,), (0.5,)),         # Normalize tensors
    transforms.Lambda(lambda x: x.view(-1))       # Flatten the tensors from [1, 28, 28] to [784]
])

# Load datasets
dataset_path = '/kaggle/input/brian-tumor-dataset/Brain Tumor Data Set/Brain Tumor Data Set/'
dataset = datasets.ImageFolder(root=dataset_path, transform=train_transform)

# Split dataset into train and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Define DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# Optionally, check the output shape
images, labels = next(iter(train_loader))
print("Image batch shape:", images.shape)  # Expected output: Image batch shape: [64, 784]
print("Label batch shape:", labels.shape)  #


Image batch shape: torch.Size([64, 784])
Label batch shape: torch.Size([64])


In [5]:
    def logpdf_diagonal_gaussian(z, mu, sigma_square):
        # Input:
        #   z: sample [batch_size x latent_dimension]
        #   mu: mean of the gaussian distribution [batch_size x latent_dimension]
        #   sigma_square: variance of the gaussian distribution [batch_size x latent_dimension]
        # Output:
        #    logprob: log-probability of a diagomnal gaussian [batch_size]
        from torch.distributions.multivariate_normal import MultivariateNormal
        logprob = torch.zeros((mu.shape[0]))
        for i in range(mu.shape[0]):
            dist = MultivariateNormal(mu[i], sigma_square[i]*torch.eye((mu.shape[1])))
            logprob[i] = dist.log_prob(z[i])        
        return logprob

In [15]:
import argparse

def get_args():
    parser = argparse.ArgumentParser(description="VAE Model Configuration")
    parser.add_argument('--latent_dimension', type=int, default=2, help='Dimensionality of the latent space')
    parser.add_argument('--hidden_units', type=int, default=500, help='Number of units in the hidden layer of the encoder and decoder')
    parser.add_argument('--data_dimension', type=int, default=784, help='Dimensionality of the flattened input images (e.g., 28*28 for MNIST)')
    parser.add_argument('--batch_size', type=int, default=100, help='Training batch size')
    parser.add_argument('--num_epochs', type=int, default=200, help='Number of epochs to train')
    parser.add_argument('--e_path', type=str, default='encoder.pth', help='Path to save the encoder weights')
    parser.add_argument('--d_path', type=str, default='decoder.pth', help='Path to save the decoder weights')
    parser.add_argument('--resume_training', type=bool, default=False, help='Flag to determine if training should be resumed from saved model')

    return parser.parse_args(args=[])  # Here, args=[] is used for illustration in notebooks or scripts without command line args

args = get_args()

In [None]:
vae_model = VAE(args)  # make sure to define args or pass relevant parameters
vae_model.train_model(train_loader, val_loader)

Epoch 1, Average Training Loss: -18.2412
Validation Loss: -44.8321
Epoch 2, Average Training Loss: -68.7700
Epoch 3, Average Training Loss: -109.8089
Epoch 4, Average Training Loss: -114.9820
Epoch 5, Average Training Loss: -118.0587
Epoch 6, Average Training Loss: -118.9946
Epoch 7, Average Training Loss: -119.2355
Epoch 8, Average Training Loss: -119.1165
Epoch 9, Average Training Loss: -119.1593
Epoch 10, Average Training Loss: -119.2186
Epoch 11, Average Training Loss: -119.2931
Validation Loss: -123.2225
Epoch 12, Average Training Loss: -119.5014
Epoch 13, Average Training Loss: -119.4736
Epoch 14, Average Training Loss: -119.6720
Epoch 15, Average Training Loss: -119.4730
Epoch 16, Average Training Loss: -119.6012
Epoch 17, Average Training Loss: -119.4560
Epoch 18, Average Training Loss: -119.5685
Epoch 19, Average Training Loss: -119.5866
Epoch 20, Average Training Loss: -119.6277
Epoch 21, Average Training Loss: -119.6319
Validation Loss: -123.4408
Epoch 22, Average Training L