In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

from torchvision.utils import make_grid

import matplotlib.pyplot as plt
import numpy as np

In [2]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import os
import torch
from torchvision import transforms
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt

class EGDDataset(Dataset):
    def __init__(self, root_dir, transform=None, use_test_data=False):
        self.root_dir = root_dir
        self.transform = transform
        self.data = []
        self.labels = []

        # Load training data
        train_dir = os.path.join(root_dir, 'train')
        for label in os.listdir(train_dir):
            label_path = os.path.join(train_dir, label)
            for file_name in os.listdir(label_path):
                if file_name.endswith('.BMP'):
                    file_path = os.path.join(label_path, file_name)
                    self.data.append(file_path)
                    self.labels.append(int(label))

        # Optionally load test data
        if use_test_data:
            test_dir = os.path.join(root_dir, 'test')
            for label in os.listdir(test_dir):
                label_path = os.path.join(test_dir, label)
                for file_name in os.listdir(label_path):
                    if file_name.endswith('.BMP'):
                        file_path = os.path.join(label_path, file_name)
                        self.data.append(file_path)
                        self.labels.append(int(label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Image loading
        img_path = self.data[idx]
        image = Image.open(img_path)

        # Label
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        if self.transform:
            image = self.transform(image)

        return image, label
    
def get_EGD_dataloader(root_dir, batch_size, transforms, shuffle=True):
    dataset = EGDDataset(root_dir=root_dir, transform=transforms)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
        

In [3]:
'''# dataloader.py
import os
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt


class_mapping = {'paella': 0, 'macaroni_and_cheese': 1, 'waffles': 2}

class CustomFoodDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        
        # Load images and labels
        for class_name, class_idx in class_mapping.items():
            class_dir = os.path.join(root_dir, 'images', class_name)
            for img_name in os.listdir(class_dir):
                self.images.append(os.path.join(class_dir, img_name))
                self.labels.append(class_idx)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

def get_food101_dataloader(batch_size, root_dir, transform):
    dataset = CustomFoodDataset(root_dir=root_dir, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
])
'''


"# dataloader.py\nimport os\nfrom torchvision import transforms\nfrom torch.utils.data import Dataset, DataLoader\nfrom PIL import Image\nimport matplotlib.pyplot as plt\n\n\nclass_mapping = {'paella': 0, 'macaroni_and_cheese': 1, 'waffles': 2}\n\nclass CustomFoodDataset(Dataset):\n    def __init__(self, root_dir, transform=None):\n        self.root_dir = root_dir\n        self.transform = transform\n        self.images = []\n        self.labels = []\n        \n        # Load images and labels\n        for class_name, class_idx in class_mapping.items():\n            class_dir = os.path.join(root_dir, 'images', class_name)\n            for img_name in os.listdir(class_dir):\n                self.images.append(os.path.join(class_dir, img_name))\n                self.labels.append(class_idx)\n\n    def __len__(self):\n        return len(self.images)\n\n    def __getitem__(self, idx):\n        img_path = self.images[idx]\n        image = Image.open(img_path).convert('RGB')\n        label =

In [4]:

'''    
class ConvVAE(nn.Module):
    def __init__(self):
        super(ConvVAE, self).__init__()

        # Encoder
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)  # 16x224x224
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 16x112x112
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)  # 32x112x112
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 32x56x56
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # 64x56x56
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 64x28x28
        # Optional: You can continue to reduce the number of layers or filters based on memory constraints

        # To compute mu and logvar of the latent vector z we use a convolutional layer
        self.mu_conv = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)  # Adjusted number of filters
        self.logvar_conv = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        # Decoder
        self.deconv1 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        # Adjust the final layer to ensure correct output size
        self.deconv4 = nn.ConvTranspose2d(16, 3, kernel_size=3, stride=1, padding=1)
        
    def encode(self, x):
        x = F.relu(self.conv1(x))
        x = self.maxpool1(x)
        x = F.relu(self.conv2(x))
        x = self.maxpool2(x)
        x = F.relu(self.conv3(x))
        x = self.maxpool3(x)
        
        mu = self.mu_conv(x) 
        logvar = self.logvar_conv(x) 
        return mu, logvar
        
        
        
    def reparametrize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = F.relu(self.deconv1(z))  # 56x56

        x = F.relu(self.deconv2(x))  # 112x112

        x = F.relu(self.deconv3(x))  # 224x224

        x = torch.sigmoid(self.deconv4(x))  # 224x224

        return x
    

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar
    
'''


'    \nclass ConvVAE(nn.Module):\n    def __init__(self):\n        super(ConvVAE, self).__init__()\n\n        # Encoder\n        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)  # 16x224x224\n        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 16x112x112\n        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)  # 32x112x112\n        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 32x56x56\n        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # 64x56x56\n        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 64x28x28\n        # Optional: You can continue to reduce the number of layers or filters based on memory constraints\n\n        # To compute mu and logvar of the latent vector z we use a convolutional layer\n        self.mu_conv = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)  # Adjusted number of filters\n        self.logvar_conv = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1

In [5]:
class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, width, height = x.size()
        query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)  # B X N X C
        key = self.key_conv(x).view(batch_size, -1, width * height)  # B X C x N
        energy = torch.bmm(query, key)  # Batch matrix multiplication, B X N X N
        attention_map = torch.softmax(energy, dim=-1)  # Softmax over the last dimension to create attention map
        value = self.value_conv(x).view(batch_size, -1, width * height)  # B X C X N
        out = torch.bmm(value, attention_map.permute(0, 2, 1))
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x  # Apply attention gamma and add input
        return out, attention_map

class ConvVAE(nn.Module):
    def __init__(self):
        super(ConvVAE, self).__init__()
        
        # Encoder
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.att1 = SelfAttention(in_dim=32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.att2 = SelfAttention(in_dim=64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.fc_mu = nn.Linear(128 * 28 * 28, 256)
        self.fc_logvar = nn.Linear(128 * 28 * 28, 256)
        
        # Decoder
        self.fc_decode = nn.Linear(256, 128 * 28 * 28)
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1)
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        x = self.relu(self.conv1(x))
        x, attention_map1 = self.att1(x)
        x = self.relu(self.conv2(x))
        x, attention_map2 = self.att2(x)
        x = self.relu(self.conv3(x))
        x = x.view(-1, 128 * 28 * 28)  # Flatten the convolutional layer output
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar
        
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        x = self.relu(self.fc_decode(z))
        x = x.view(-1, 128, 28, 28)  # Unflatten to prepare for transposed convolutions
        x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))
        x = self.sigmoid(self.deconv3(x))
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar

# Now, let's instantiate and view the model to ensure it's structured correctly.
conv_vae = ConvVAE()
print(conv_vae)


ConvVAE(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (att1): SelfAttention(
    (query_conv): Conv2d(32, 4, kernel_size=(1, 1), stride=(1, 1))
    (key_conv): Conv2d(32, 4, kernel_size=(1, 1), stride=(1, 1))
    (value_conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
  )
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (att2): SelfAttention(
    (query_conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
    (key_conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
    (value_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
  )
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (fc_mu): Linear(in_features=100352, out_features=256, bias=True)
  (fc_logvar): Linear(in_features=100352, out_features=256, bias=True)
  (fc_decode): Linear(in_features=256, out_features=100352, bias=True)
  (deconv1): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), outp

In [6]:
#ConvVAE loss function
#Loss will be the sum of the reconstruction loss and the KL divergence loss
#The reconstruction loss will be Mean Squared Error (MSE) loss
#The KL divergence loss will be the KL divergence between the learned mean and variance and the prior Gaussian distribution

def loss_function(recon_x, x, mu, logvar):
    # Normalized MSE loss
    recon_loss = F.mse_loss(recon_x.view(-1, 3*224*224), x.view(-1, 3*224*224), reduction='sum') / x.size(0)
    
    # KL divergence loss
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    
    return recon_loss + KLD

In [7]:
import matplotlib.pyplot as plt
import wandb
def log_images(original, label, reconstructed):
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    # Original image
    axs[0].imshow(original.permute(1, 2, 0).cpu().detach().numpy(), cmap='gray')
    axs[0].set_title(f'Original Image, Label: {label}')
    # Reconstructed image
    axs[1].imshow(reconstructed.permute(1, 2, 0).cpu().detach().numpy(), cmap='gray')
    axs[1].set_title('Reconstructed Image')
    plt.axis('off')
    wandb.log({"Original vs reconstructed image": plt})


In [8]:
# Define number of epochs
num_epochs = 1000

# Define device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Create model
model = ConvVAE()

# Move model to device
model.to(device)

#learning rate
lr = 3e-4

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)

# Create dataloader
root_dir = '/home/ndelafuente/VAE_food101/data_egd'
batch_size = 4

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_loader = get_EGD_dataloader(root_dir=root_dir, batch_size=batch_size, transforms=transform)

#Initialize wandb
import wandb

wandb.init(project="vae_egd", entity="neildlf")


# Training loop
for epoch in range(num_epochs):
    model.train()
    train_loss = 0

    for batch_idx, (data, label) in enumerate(train_loader):
        # Move data to device
        data = data.to(device)
        optimizer.zero_grad()

        # Forward pass through the model
        recon_batch, mu, logvar = model(data)
        
        # Compute loss
        loss = loss_function(recon_batch, data, mu, logvar)
        train_loss += loss.item()

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item()/len(data)}')
            # Visualize the first image in the batch
            
            label = label[0] # Get the label for the first image in the batch which is what we will visualize
            log_images(data[0], label.item(), recon_batch[0])
            # Log the loss
            wandb.log({"loss": loss.item()/len(data)})
            
    # Print average loss for the epoch
    print(f'====> Epoch: {epoch}, Average loss: {train_loss / len(train_loader.dataset)}')
    wandb.log({"average_loss": train_loss / len(train_loader.dataset)})
    # Save the model every 2 epochs
    if epoch % 100 == 0:
        torch.save(model.state_dict(), f'vae_egd_epoch_{epoch}.pth')


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mneildlf[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch: 0, Batch: 0, Loss: 3171.4462890625
====> Epoch: 0, Average loss: 2176.082705145013
Epoch: 1, Batch: 0, Loss: 1452.3004150390625
====> Epoch: 1, Average loss: 996.5808699098352
Epoch: 2, Batch: 0, Loss: 701.3031616210938
====> Epoch: 2, Average loss: 643.1130433801102
Epoch: 3, Batch: 0, Loss: 515.1572875976562
====> Epoch: 3, Average loss: 515.4676200135114
Epoch: 4, Batch: 0, Loss: 544.792724609375
====> Epoch: 4, Average loss: 457.6659721479024
Epoch: 5, Batch: 0, Loss: 360.4516906738281
====> Epoch: 5, Average loss: 412.8030583629869
Epoch: 6, Batch: 0, Loss: 492.7398376464844
====> Epoch: 6, Average loss: 386.81776365515304
Epoch: 7, Batch: 0, Loss: 310.71649169921875
====> Epoch: 7, Average loss: 362.12177809623824
Epoch: 8, Batch: 0, Loss: 314.9486999511719
====> Epoch: 8, Average loss: 340.6172934754254
Epoch: 9, Batch: 0, Loss: 297.1343078613281
====> Epoch: 9, Average loss: 326.68792390170165
Epoch: 10, Batch: 0, Loss: 389.5758056640625
====> Epoch: 10, Average loss: 32

In [None]:
model = ConvVAE()
dummy_input = torch.randn(1, 3, 224, 224)  # Batch size of 1
recon, mu, logvar = model(dummy_input)
print("Reconstructed size:", recon.size())  # Should be [1, 3, 224, 224]


mu shape: torch.Size([1, 64, 28, 28])
logvar shape: torch.Size([1, 64, 28, 28])
z shape: torch.Size([1, 64, 28, 28])
x shape after deconv1: torch.Size([1, 64, 56, 56])
x shape after deconv2: torch.Size([1, 32, 112, 112])
x shape after deconv3: torch.Size([1, 16, 224, 224])
x shape after deconv4: torch.Size([1, 3, 224, 224])
Reconstructed size: torch.Size([1, 3, 224, 224])
