In [37]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import copy
import os
import subprocess

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms

from torch.utils.data import DataLoader, Dataset

In [38]:
from utils.datasets import WildfireDataset

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = WildfireDataset('/data/amathur-23/ROB313', split='train', labeled=False, transforms=transform)
data_train_labeled = WildfireDataset('/data/amathur-23/ROB313', split='train', labeled=True, transforms=transform)
val_dataset = WildfireDataset('/data/amathur-23/ROB313', split='val', transforms=transform)
test_dataset = WildfireDataset('/data/amathur-23/ROB313', split='test', transforms=transform)

Loading meta file: /data/amathur-23/ROB313/train_unlabeled.csv
Loading meta file: /data/amathur-23/ROB313/train.csv
Loading meta file: /data/amathur-23/ROB313/val.csv
Loading meta file: /data/amathur-23/ROB313/test.csv


In [39]:

batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
train_loader_labeled = DataLoader(data_train_labeled, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [5]:
class ConvVAE(nn.Module):
    def __init__(self, latent_dim=128):
        super(ConvVAE, self).__init__()
        
        # Encoder
        # 3x224x224
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1), # 224 -> 112
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1), # 112 -> 56
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1), # 56 -> 28
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),  # 28 -> 14
            nn.ReLU(),
            nn.Conv2d(256, 256, 4, stride=2, padding=1),  # 14 -> 7
            nn.ReLU()
        )
        
        self.encoder_output_dim = (256 * 7 * 7)
        self.fc_mu = nn.Linear(self.encoder_output_dim, latent_dim)
        self.fc_var = nn.Linear(self.encoder_output_dim, latent_dim)
        
        # Decoder
        self.decoder_input = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.Linear(256, self.encoder_output_dim)
        ) 
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 256, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),
            nn.Sigmoid()
        )
        
    def encode(self, x):
        batch_size = x.size(0)
        x = self.encoder(x)
        x = x.view(batch_size,-1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        x = self.decoder_input(z)
        x = x.view(x.size(0), 256, 7, 7) 
        x = self.decoder(x)
        return x
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        z = self.decode(z)
        return z, mu, log_var

In [6]:
class BetaVAELoss(nn.Module):
    def __init__(self, beta=1):
        super(BetaVAELoss, self).__init__()
        self.beta = beta
        
    def forward(self, x, recon_x, mu, logvar):
        recon_loss = F.mse_loss(recon_x, x, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + self.beta * kl_loss
    
criterion_vae = BetaVAELoss(beta=1)

In [7]:
from tqdm import tqdm
def train(model, dataloader, optimizer, device, epoch):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, f"Training {epoch}"):
        data = batch['image'].to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = criterion_vae(data, recon_batch, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader.dataset)

# Validation Function
def validate(model, dataloader, device, epoch):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, f"Validation {epoch}"):
            data = batch['image'].to(device)
            recon_batch, mu, logvar = model(data)
            loss = criterion_vae(data, recon_batch, mu, logvar)
            total_loss += loss.item()
    return total_loss / len(dataloader.dataset)

In [10]:
latent_dim = 256
learning_rate = 1e-5
num_epochs = 50

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ConvVAE(latent_dim=latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
train_losses = []
val_losses = []
# for epoch in range(num_epochs):
#     train_loss = train(model, train_loader, optimizer, device, epoch)
#     print(f"Epoch {epoch} Train loss: {train_loss}")
#     val_loss = validate(model, val_loader, device, epoch)
#     print(f"Epoch {epoch} Validation loss: {val_loss}")

In [8]:
from sklearn.cluster import KMeans, DBSCAN
from sklearn.mixture import GaussianMixture
def perform_clustering(features, method="kmeans", num_clusters=2):
    if method == "kmeans":
        clustering = KMeans(n_clusters=num_clusters, random_state=42).fit(features)
    elif method == "gmm":
        clustering = GaussianMixture(n_components=num_clusters, random_state=42).fit(features)
    elif method == "dbscan":
        clustering = DBSCAN(eps=0.5, min_samples=5).fit(features)
    else:
        raise ValueError("Unsupported clustering method")
    return clustering.labels_


# labels = perform_clustering(labelled_features, method="kmeans", num_clusters=2)

In [11]:
class ClassifierFeatures(nn.Module):
    def __init__(self, vae, device, input_dim=256, dropout=0.1):
        super(ClassifierFeatures, self).__init__()
        self.vae = vae.to(device)  
        self.vae.eval()  
        self.device = device
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
            nn.Sigmoid()
        ).to(device)  

    def forward(self, x):
        with torch.no_grad():
            mu, logvar = self.vae.encode(x)
            x = self.vae.reparameterize(mu, logvar)
        return self.fc(x)

In [34]:
class ClassifierFeatures_Coords(nn.Module):
    def __init__(self, vae, device, input_dim=256, dropout=0.1):
        super(ClassifierFeatures_Coords, self).__init__()
        self.vae = vae
        self.vae.eval()  
        self.fc = nn.Sequential(
            nn.Linear(input_dim+2, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x, coords):
        with torch.no_grad():
            mu, logvar = self.vae.encode(x)
            x = self.vae.reparameterize(mu, logvar)
        x = torch.cat((x, coords), dim=1)
        return self.fc(x)

In [27]:
from sklearn.metrics import f1_score
def train_classifier(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch in tqdm(train_loader):
        target = batch['label'].float().to(device)  
        image = batch['image'].to(device)
        coords = batch['coords'].to(device)
        optimizer.zero_grad()
        output = model(image, coords).squeeze()  
        
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        predicted = (output > 0.5).float() 
        correct += (predicted == target).sum().item()
        total += target.size(0)

    accuracy = 100. * correct / total
    return total_loss / len(train_loader), accuracy

def validate_classifier(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for batch in tqdm(val_loader):
            target = batch['label'].float().to(device)
            image = batch['image'].to(device)
            coords = batch['coords'].to(device)
            output = model(image, coords).squeeze()
            loss = criterion(output, target)

            total_loss += loss.item()
            predicted = (output > 0.5).float()  
            correct += (predicted == target).sum().item()
            total += target.size(0)
            all_preds.append(predicted.cpu().numpy())
            all_targets.append(target.cpu().numpy())
            
        f1 = f1_score(np.concatenate(all_targets), np.concatenate(all_preds))
        print(f'Validation Loss: {total_loss / len(val_loader)}')
        print(f'Validation Accuracy: {100. * correct / total}')
        print(f'Validation F1 Score: {f1}')
        return total_loss / len(val_loader)

In [46]:
def extract_features(model, dataloader, device, labels = True):
    model.eval()
    all_features = []
    all_targets = []
    with torch.no_grad():
        for batch in dataloader:
            data = batch['image'].to(device)
            target = batch['label'].float().to(device) if labels else None
            mu, _ = model.encode(data)
            all_features.append(mu.cpu().numpy())
            all_targets.append(target.cpu().numpy()) if labels else None
    if labels:
        return np.concatenate(all_features), np.concatenate(all_targets)
    return np.concatenate(all_features)

In [35]:
latent_dim = 256
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ConvVAE(latent_dim=latent_dim).to(device)
model.load_state_dict(torch.load('/data/iivanova-23/ROB313/models/vae_trial_256_32/vae_model.pth'))
classifier = ClassifierFeatures_Coords(model, device, input_dim=latent_dim).to(device) 
optimizer = optim.Adam(classifier.parameters(), lr=1e-4)
criterion = nn.BCELoss()
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train_classifier(classifier, train_loader_labeled, optimizer,criterion, device)
    print(f"Epoch {epoch} Train loss: {train_loss}")
    val_loss = validate_classifier(classifier, val_loader, criterion, device)
    print(f"Epoch {epoch} Validation loss: {val_loss}")

# features = extract_features(model, train_loader, device, labels=False)
# labels = perform_clustering(features, method="kmeans", num_clusters=2)

100%|██████████| 79/79 [00:25<00:00,  3.08it/s]


Epoch 0 Train loss: (0.6275317276580424, 69.5436507936508)


100%|██████████| 20/20 [00:06<00:00,  3.03it/s]


Validation Loss: 0.5598565071821213
Validation Accuracy: 81.03174603174604
Validation F1 Score: 0.8503443957420163
Epoch 0 Validation loss: 0.5598565071821213


100%|██████████| 79/79 [00:24<00:00,  3.18it/s]


Epoch 1 Train loss: (0.5151921552193316, 79.98015873015873)


100%|██████████| 20/20 [00:06<00:00,  3.15it/s]


Validation Loss: 0.42805579751729966
Validation Accuracy: 85.31746031746032
Validation F1 Score: 0.8804137039431157
Epoch 1 Validation loss: 0.42805579751729966


100%|██████████| 79/79 [00:25<00:00,  3.10it/s]


Epoch 2 Train loss: (0.40536109103432183, 84.7420634920635)


100%|██████████| 20/20 [00:06<00:00,  3.23it/s]


Validation Loss: 0.34037434607744216
Validation Accuracy: 86.74603174603175
Validation F1 Score: 0.8907782864617397
Epoch 2 Validation loss: 0.34037434607744216


100%|██████████| 79/79 [00:24<00:00,  3.26it/s]


Epoch 3 Train loss: (0.34006730589685563, 87.26190476190476)


100%|██████████| 20/20 [00:06<00:00,  3.28it/s]


Validation Loss: 0.2867875225841999
Validation Accuracy: 90.15873015873017
Validation F1 Score: 0.9159891598915989
Epoch 3 Validation loss: 0.2867875225841999


100%|██████████| 79/79 [00:24<00:00,  3.24it/s]


Epoch 4 Train loss: (0.2899406126028375, 89.24603174603175)


100%|██████████| 20/20 [00:06<00:00,  3.29it/s]


Validation Loss: 0.24709537103772164
Validation Accuracy: 92.38095238095238
Validation F1 Score: 0.9337016574585635
Epoch 4 Validation loss: 0.24709537103772164


100%|██████████| 79/79 [00:23<00:00,  3.31it/s]


Epoch 5 Train loss: (0.25824374884744233, 90.77380952380952)


100%|██████████| 20/20 [00:06<00:00,  3.33it/s]


Validation Loss: 0.22053931206464766
Validation Accuracy: 93.73015873015873
Validation F1 Score: 0.9440905874026894
Epoch 5 Validation loss: 0.22053931206464766


100%|██████████| 79/79 [00:24<00:00,  3.19it/s]


Epoch 6 Train loss: (0.24250185565103458, 91.19047619047619)


100%|██████████| 20/20 [00:06<00:00,  3.22it/s]


Validation Loss: 0.20223401337862015
Validation Accuracy: 94.44444444444444
Validation F1 Score: 0.95
Epoch 6 Validation loss: 0.20223401337862015


100%|██████████| 79/79 [00:23<00:00,  3.32it/s]


Epoch 7 Train loss: (0.2285790360426601, 91.56746031746032)


100%|██████████| 20/20 [00:05<00:00,  3.42it/s]


Validation Loss: 0.19233059883117676
Validation Accuracy: 94.36507936507937
Validation F1 Score: 0.9491768074445239
Epoch 7 Validation loss: 0.19233059883117676


100%|██████████| 79/79 [00:24<00:00,  3.25it/s]


Epoch 8 Train loss: (0.2202656407710872, 91.76587301587301)


100%|██████████| 20/20 [00:05<00:00,  3.35it/s]


Validation Loss: 0.17730557322502136
Validation Accuracy: 94.52380952380952
Validation F1 Score: 0.9508196721311475
Epoch 8 Validation loss: 0.17730557322502136


100%|██████████| 79/79 [00:23<00:00,  3.33it/s]


Epoch 9 Train loss: (0.2052654890885836, 92.89682539682539)


100%|██████████| 20/20 [00:05<00:00,  3.38it/s]

Validation Loss: 0.17733179740607738
Validation Accuracy: 94.92063492063492
Validation F1 Score: 0.9540889526542324
Epoch 9 Validation loss: 0.17733179740607738





In [36]:
validate_classifier(classifier, test_loader, nn.BCELoss(), device)

100%|██████████| 99/99 [00:29<00:00,  3.33it/s]

Validation Loss: 0.1912333754578022
Validation Accuracy: 93.42752817907605
Validation F1 Score: 0.9418539325842696





0.1912333754578022

## Cluster features from resnet
