In [76]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import copy
import os
import subprocess

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms

from torch.utils.data import DataLoader, Dataset
from typing import List, Callable, Union, Any, TypeVar, Tuple

import torchvision.models as models

In [None]:
class VectorQuantizer(nn.Module):
    """
    Reference:
    [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
    """
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 beta: float = 0.25):
        super(VectorQuantizer, self).__init__()
        self.K = num_embeddings
        self.D = embedding_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.K, self.D)
        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        latents = latents.permute(0, 2, 3, 1).contiguous()  # [B x D x H x W] -> [B x H x W x D]
        latents_shape = latents.shape
        flat_latents = latents.view(-1, self.D)  # [BHW x D]

        # Compute L2 distance between latents and embedding weights
        dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
               torch.sum(self.embedding.weight ** 2, dim=1) - \
               2 * torch.matmul(flat_latents, self.embedding.weight.t())  # [BHW x K]

        # Get the encoding that has the min distance
        encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW, 1]

        # Convert to one-hot encodings
        device = latents.device
        encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
        encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]

        # Quantize the latents
        quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)  # [BHW, D]
        quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]

        # Compute the VQ Losses
        commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
        embedding_loss = F.mse_loss(quantized_latents, latents.detach())

        vq_loss = commitment_loss * self.beta + embedding_loss

        # Add the residue back to the latents
        quantized_latents = latents + (quantized_latents - latents).detach()

        return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss  # [B x D x H x W]

class ResidualLayer(nn.Module):

    def __init__(self,
                 in_channels: int,
                 out_channels: int):
        super(ResidualLayer, self).__init__()
        self.resblock = nn.Sequential(nn.Conv2d(in_channels, out_channels,
                                                kernel_size=3, padding=1, bias=False),
                                      nn.ReLU(True),
                                      nn.Conv2d(out_channels, out_channels,
                                                kernel_size=1, bias=False))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input + self.resblock(input)


class VQVAE(nn.Module):

    def __init__(self,
                 in_channels: int,
                 embedding_dim: int,
                 num_embeddings: int,
                 hidden_dims: List = None,
                 beta: float = 0.25,
                 pretrained: bool = False,
                 **kwargs) -> None:
        super(VQVAE, self).__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta
        self.pretrained = pretrained
        
        # Resnet Encoder
        if self.pretrained:
            resnet = models.resnet50(weights="DEFAULT")
            modules = list(resnet.children())[:-1]
            self.resnet = nn.Sequential(*modules)
            in_channels = 2048
            
        modules = []
        if hidden_dims is None:
            hidden_dims = [128, 256]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=1, padding=1),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, in_channels,
                          kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU())
        )

        for _ in range(6):
            modules.append(ResidualLayer(in_channels, in_channels))
        modules.append(nn.LeakyReLU())

        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, embedding_dim,
                          kernel_size=1, stride=1),
                nn.LeakyReLU())
        )

        self.encoder = nn.Sequential(*modules)

        self.vq_layer = VectorQuantizer(num_embeddings,
                                        embedding_dim,
                                        self.beta)

        # Build Decoder
        modules = []
        modules.append(
            nn.Sequential(
                nn.Conv2d(embedding_dim,
                          hidden_dims[-1],
                          kernel_size=3,
                          stride=1,
                          padding=1),
                nn.LeakyReLU())
        )

        for _ in range(6):
            modules.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1]))

        modules.append(nn.LeakyReLU())

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=4,
                                       stride=2,
                                       padding=1),
                    nn.LeakyReLU())
            )

        modules.append(
            nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1] // 2, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(hidden_dims[-1] // 2, out_channels=3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()))
        modules.append(nn.Upsample(size=(224, 224), mode='bilinear', align_corners=False))

        self.decoder = nn.Sequential(*modules)

    def encode(self, input: torch.Tensor) -> List[torch.Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        if self.pretrained:
            input = self.resnet(input)
        result = self.encoder(input)
        return result

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        result = self.decoder(z)
        return result

    def forward(self, input: torch.Tensor, **kwargs) -> List[torch.Tensor]:
        # input = input.to(next(self.parameters()).dtype)  # Ensure input type matches model parameters
        # print("Input Shape:", input.shape)
        encoding = self.encode(input)
        quantized_inputs, vq_loss = self.vq_layer(encoding)
        z = self.decode(quantized_inputs)
        return [z, input , vq_loss, quantized_inputs]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        :param args:
        :param kwargs:
        :return:
        """
        recons = args[0]
        input = args[1]
        vq_loss = args[2]
        # print("Reconstructed Shape:", recons.shape)
        recons_loss = F.mse_loss(recons, input)
        
        loss = recons_loss + vq_loss
        return {'loss': loss,
                'Reconstruction_Loss': recons_loss,
                'VQ_Loss':vq_loss}

    def sample(self,
               num_samples: int,
               current_device: Union[int, str], **kwargs) -> torch.Tensor:
        raise Warning('VQVAE sampler is not implemented.')

    def generate(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

In [81]:
from tqdm import tqdm
from torch.amp import autocast
def train(model, dataloader, optimizer, device, epoch):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, f"Training {epoch}"):
        data = batch['image'].to(device)
        optimizer.zero_grad()
        with autocast(
            device_type="cuda", dtype=torch.float16
        ):
            
            args = model(data)
            loss = model.loss_function(*args)['loss']
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader.dataset)

# Validation Function
def validate(model, dataloader, device, epoch):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, f"Validation {epoch}"):
            data = batch['image'].to(device)
            
            with autocast(
                device_type="cuda", dtype=torch.float16
            ):
                args = model(data)
                loss = model.loss_function(*args)['loss']
            total_loss += loss.item()
    return total_loss / len(dataloader.dataset)

In [82]:
from utils.datasets import WildfireDataset

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = WildfireDataset('/data/amathur-23/ROB313', split='train', labeled=False, transforms=transform)
data_train_labeled = WildfireDataset('/data/amathur-23/ROB313', split='train', labeled=True, transforms=transform)
val_dataset = WildfireDataset('/data/amathur-23/ROB313', split='val', transforms=transform)
test_dataset = WildfireDataset('/data/amathur-23/ROB313', split='test', transforms=transform)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
train_loader_labeled = DataLoader(data_train_labeled, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

Loading meta file: /data/amathur-23/ROB313/train_unlabeled.csv
Loading meta file: /data/amathur-23/ROB313/train.csv
Loading meta file: /data/amathur-23/ROB313/val.csv
Loading meta file: /data/amathur-23/ROB313/test.csv


In [74]:
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
learning_rate = 1e-4
num_epochs = 4

In [117]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VQVAE(in_channels=3, embedding_dim=64, num_embeddings=512, hidden_dims=[128, 256], beta=0.25, pretrained=True)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
learning_rate = 1e-4
num_epochs = 3
train_losses = []
val_losses = []
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, device, epoch)
    print(f"Epoch {epoch} Train loss: {train_loss}")
    val_loss = validate(model, val_loader, device, epoch)
    print(f"Epoch {epoch} Validation loss: {val_loss}")


Training 0: 100%|██████████| 946/946 [00:51<00:00, 18.34it/s]


Epoch 0 Train loss: 0.28576378948200065


Validation 0: 100%|██████████| 40/40 [00:01<00:00, 22.71it/s]


Epoch 0 Validation loss: 0.012841783866049752


Training 1: 100%|██████████| 946/946 [00:52<00:00, 18.02it/s]


Epoch 1 Train loss: 0.014188698789037639


Validation 1: 100%|██████████| 40/40 [00:01<00:00, 23.02it/s]


Epoch 1 Validation loss: 0.0139276926243116


Training 2: 100%|██████████| 946/946 [00:54<00:00, 17.21it/s]


Epoch 2 Train loss: 0.013792412497354202


Validation 2: 100%|██████████| 40/40 [00:01<00:00, 21.66it/s]

Epoch 2 Validation loss: 0.012901409963766734





In [109]:
class ClassifierFeatures_Coords(nn.Module):
    def __init__(self, vae, input_dim=256, dropout=0.1):
        super(ClassifierFeatures_Coords, self).__init__()
        self.vae = vae
        self.vae.eval()  
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x, coords=None):
        with torch.no_grad():
            x = self.vae.encode(x)
            x, _ = self.vae.vq_layer(x)
        batch_size = x.shape[0]
        x = x.view(batch_size, -1) 
        x = torch.cat((x, coords), dim=1)
        return self.fc(x)

In [110]:
from sklearn.metrics import f1_score
def train_classifier(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch in tqdm(train_loader):
        target = batch['label'].float().to(device)  
        image = batch['image'].to(device)
        coords = batch['coords'].to(device)
        optimizer.zero_grad()
        output = model(image, coords).squeeze()  
        
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        predicted = (output > 0.5).float() 
        correct += (predicted == target).sum().item()
        total += target.size(0)

    accuracy = 100. * correct / total
    return total_loss / len(train_loader), accuracy

def validate_classifier(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for batch in tqdm(val_loader):
            target = batch['label'].float().to(device)
            image = batch['image'].to(device)
            coords = batch['coords'].to(device)
            output = model(image, coords).squeeze()
            loss = criterion(output, target)

            total_loss += loss.item()
            predicted = (output > 0.5).float()  
            correct += (predicted == target).sum().item()
            total += target.size(0)
            all_preds.append(predicted.cpu().numpy())
            all_targets.append(target.cpu().numpy())
            
        f1 = f1_score(np.concatenate(all_targets), np.concatenate(all_preds))
        # print(f'Validation Loss: {total_loss / len(val_loader)}')
        # print(f'Validation Accuracy: {100. * correct / total}')
        # print(f'Validation F1 Score: {f1}')
        return total_loss / len(val_loader), correct / total, f1

In [119]:
latent_dim = 66
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
classifier = ClassifierFeatures_Coords(model, input_dim=latent_dim).to(device) 
optimizer = optim.Adam(classifier.parameters(), lr=1e-4)
criterion_classif = nn.BCELoss()
num_epochs = 15
for epoch in range(num_epochs):
    train_loss = train_classifier(classifier, train_loader_labeled, optimizer,criterion_classif, device)
    print(f"Epoch {epoch} Train loss: {train_loss}")
    val_loss = validate_classifier(classifier, val_loader, criterion_classif, device)
    print(f"Epoch {epoch} Validation loss: {val_loss}")


100%|██████████| 158/158 [00:05<00:00, 29.16it/s]


Epoch 0 Train loss: (0.5166283981709541, 73.37301587301587)


100%|██████████| 40/40 [00:01<00:00, 22.29it/s]


Epoch 0 Validation loss: (0.33649000972509385, 0.8428571428571429, 0.8735632183908046)


100%|██████████| 158/158 [00:05<00:00, 29.14it/s]


Epoch 1 Train loss: (0.34242566288272036, 83.88888888888889)


100%|██████████| 40/40 [00:01<00:00, 23.34it/s]


Epoch 1 Validation loss: (0.26953173987567425, 0.8960317460317461, 0.9081990189208129)


100%|██████████| 158/158 [00:05<00:00, 29.71it/s]


Epoch 2 Train loss: (0.28682282557593114, 86.7063492063492)


100%|██████████| 40/40 [00:01<00:00, 22.66it/s]


Epoch 2 Validation loss: (0.23515386022627355, 0.9253968253968254, 0.930780559646539)


100%|██████████| 158/158 [00:05<00:00, 29.66it/s]


Epoch 3 Train loss: (0.2564816809719122, 88.41269841269842)


100%|██████████| 40/40 [00:01<00:00, 21.44it/s]


Epoch 3 Validation loss: (0.21363723538815976, 0.9031746031746032, 0.9168937329700273)


100%|██████████| 158/158 [00:05<00:00, 30.11it/s]


Epoch 4 Train loss: (0.24876164494058753, 88.71031746031746)


100%|██████████| 40/40 [00:01<00:00, 23.48it/s]


Epoch 4 Validation loss: (0.2000930305570364, 0.9373015873015873, 0.9429602888086642)


100%|██████████| 158/158 [00:05<00:00, 28.78it/s]


Epoch 5 Train loss: (0.2412049202602121, 89.38492063492063)


100%|██████████| 40/40 [00:01<00:00, 23.37it/s]


Epoch 5 Validation loss: (0.20517232529819013, 0.9111111111111111, 0.9141104294478528)


100%|██████████| 158/158 [00:05<00:00, 30.14it/s]


Epoch 6 Train loss: (0.23313495629950415, 89.56349206349206)


100%|██████████| 40/40 [00:01<00:00, 22.71it/s]


Epoch 6 Validation loss: (0.21387940356507898, 0.9, 0.9149797570850202)


100%|██████████| 158/158 [00:05<00:00, 31.05it/s]


Epoch 7 Train loss: (0.23250562942857983, 89.76190476190476)


100%|██████████| 40/40 [00:01<00:00, 23.61it/s]


Epoch 7 Validation loss: (0.19509986899793147, 0.9206349206349206, 0.9298737727910238)


100%|██████████| 158/158 [00:05<00:00, 31.07it/s]


Epoch 8 Train loss: (0.22472561004606983, 90.55555555555556)


100%|██████████| 40/40 [00:01<00:00, 23.36it/s]


Epoch 8 Validation loss: (0.22616538871079683, 0.8841269841269841, 0.8841269841269841)


100%|██████████| 158/158 [00:05<00:00, 30.61it/s]


Epoch 9 Train loss: (0.22666252913731563, 90.23809523809524)


100%|██████████| 40/40 [00:01<00:00, 23.23it/s]


Epoch 9 Validation loss: (0.19455335047096015, 0.9134920634920635, 0.9251887439945092)


100%|██████████| 158/158 [00:05<00:00, 30.81it/s]


Epoch 10 Train loss: (0.22592095879814292, 90.27777777777777)


100%|██████████| 40/40 [00:01<00:00, 22.75it/s]


Epoch 10 Validation loss: (0.18829679843038322, 0.926984126984127, 0.9356643356643357)


100%|██████████| 158/158 [00:05<00:00, 29.34it/s]


Epoch 11 Train loss: (0.22544580452804325, 90.05952380952381)


100%|██████████| 40/40 [00:01<00:00, 22.64it/s]


Epoch 11 Validation loss: (0.1907964127138257, 0.9174603174603174, 0.9279778393351801)


100%|██████████| 158/158 [00:05<00:00, 30.29it/s]


Epoch 12 Train loss: (0.21921639979074273, 90.87301587301587)


100%|██████████| 40/40 [00:01<00:00, 23.29it/s]


Epoch 12 Validation loss: (0.17707180129364133, 0.9380952380952381, 0.9441260744985673)


100%|██████████| 158/158 [00:05<00:00, 30.45it/s]


Epoch 13 Train loss: (0.2152173757741723, 90.9920634920635)


100%|██████████| 40/40 [00:01<00:00, 23.53it/s]


Epoch 13 Validation loss: (0.17884824145585299, 0.9357142857142857, 0.9422665716322167)


100%|██████████| 158/158 [00:05<00:00, 29.82it/s]


Epoch 14 Train loss: (0.2158301167095764, 90.67460317460318)


100%|██████████| 40/40 [00:01<00:00, 21.51it/s]

Epoch 14 Validation loss: (0.17349596777930856, 0.9357142857142857, 0.9409190371991247)





In [120]:
validate_classifier(classifier, test_loader, criterion_classif, device)

100%|██████████| 197/197 [00:06<00:00, 30.78it/s]


(0.18376188817892583, 0.9287188442609938, 0.9360125409719253)