In [49]:
import pandas as pd
import numpy as np
import os
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler, Subset
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torchvision.datasets import ImageFolder
import torch.optim as optim
from torchvision import models
from torchvision.models import vit_b_16, ViT_B_16_Weights
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from torchmetrics.classification import MulticlassF1Score
import torch.nn.functional as F
from torch.utils.data import ConcatDataset
from torch.utils.data import Dataset
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
base_dir = "Images/FloodNet Challenge - Track 1"

In [3]:
# Define data augmentations and preprocessing
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet stats
])

In [6]:
# Load the labeled training data
train_dataset = datasets.ImageFolder(f'{base_dir}/Train/Labeled', transform=train_transforms)
# Get the class names and their corresponding indices
class_to_idx = train_dataset.class_to_idx
print(f"Class to Index Mapping: {class_to_idx}")

# Count the samples in each class
class_counts = {class_name: 0 for class_name in class_to_idx.keys()}
for _, label in train_dataset.samples:
    for class_name, class_idx in class_to_idx.items():
        if label == class_idx:
            class_counts[class_name] += 1

print("Class Distribution:")
for class_name, count in class_counts.items():
    print(f"{class_name}: {count}")

Class to Index Mapping: {'Flooded': 0, 'Non-Flooded': 1}
Class Distribution:
Flooded: 51
Non-Flooded: 347


In [7]:
# Large class imbalance so need to use WeightedRandomSampler to ensure balanced mini-batches
# during training.

# Extract class labels from the dataset
targets = train_dataset.targets

# Calculate class weights, which are the inverse of class frequencies. Classes with fewer samples will get assigned a higher 
# weight (ensuring that the minority class receives a higher weight, making it more likely to be sampled during training).
class_counts = np.bincount(targets)  # Count the number of samples per class
class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32)

sample_weights = torch.tensor([class_weights[label] for label in targets], dtype=torch.float)

# Create the WeightedRandomSampler
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

# Create DataLoaders with the sampler for training
train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler, num_workers=4)

print(f"Number of training samples: {len(train_loader.dataset)}")

Number of training samples: 398


  class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32)


In [None]:
# Import the ViT model for transfer learning
model = models.vit_b_16(weights=ViT_B_16_Weights.DEFAULT )
# Freeze the parameters in the base model so only the new layers are being updated
for param in model.parameters():
    param.requires_grad = False

# Replace the final layer with a FC layer
num_classes = len(class_counts)

model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)

# Unfreeze the final layer to allow it to learn during training
for param in model.heads.head.parameters():
    param.requires_grad = True

model = model.to(device)

# Initialize the loss function with the class weights to penalize errors made on the minority class more heavily.
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

# Set up the optimizer
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

# Decrease learning rate by 0.1 if loss hasn't decreased after 3 consecutive epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=3, verbose=True
)
# Set the number of epochs for training
num_epochs = 20

# Metrics tracker
f1_metric = MulticlassF1Score(num_classes=num_classes, average='weighted').to(device)

In [53]:
def train_model(model, train_loader, criterion, optimizer, num_epochs, device, 
                is_initial=True, val_loader=None):
    best_loss = float('inf')
    
    save_path = "best_initial_model.pth" if is_initial else "best_retrained_model.pth"
        
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}")
        
        # Check if validation should be used for model saving
        if not is_initial and val_loader is not None:
            # Validation step
            model.eval()
            total_confidence = 0.0
            with torch.no_grad():
                for val_images in val_loader:
                    val_images = val_images.to(device)
                    val_outputs = model(val_images)
                    confidences = torch.softmax(val_outputs, dim=1).max(dim=1).values  # Max confidence per sample
                    total_confidence += confidences.sum().item()
            
            avg_confidence = total_confidence / len(val_loader.dataset)
            print(f"Average Confidence: {avg_confidence:.4f}")
        
        else:
            # Save the best model based on training loss
            if avg_train_loss < best_loss:
                best_loss = avg_train_loss
                torch.save(model.state_dict(), save_path)
                print(f"New best model saved with training loss: {best_loss:.4f}")
    
    return model


In [None]:
trained_model = train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=5,
    device=device,
    is_initial=True,
)

In [14]:
def generate_pseudo_labels(model, unlabeled_dataset, batch_size, confidence_threshold=0.9):
    # Path to the best performing initial model
    model_path = os.path.join(".", "best_initial_model.pth")
    model.load_state_dict(torch.load(model_path, weights_only=True))
    print("Loaded best initial model for generation of pseudo labels on unlabeled training data")
    
    # DataLoader for unlabeled data
    unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=32, shuffle=False, num_workers=4)
    
    pseudo_labeled_data = []
    
    with torch.no_grad():
        for images, _ in unlabeled_loader:  # The labels are not used for unlabeled data
            images = images.to(device)
            outputs = model(images)
            probabilities = F.softmax(outputs, dim=1)  # Get class probabilities
            confidences, pseudo_labels = torch.max(probabilities, dim=1)  # Max confidence and corresponding class
    
            # Filter based on confidence threshold
            for i in range(len(images)):
                if confidences[i] > confidence_threshold:
                    pseudo_labeled_data.append((images[i].cpu(), pseudo_labels[i].cpu()))
    
    print(f"Pseudo-labeled {len(pseudo_labeled_data)} images from the unlabeled dataset.")

    return pseudo_labeled_data

In [35]:
class PseudoLabeledDataset(Dataset):
    def __init__(self, data):
        self.data = data  # data should be a list of (image, label) tuples

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image, label = self.data[index]
        # Ensure the image is a tensor, and the label is a tensor of type torch.long
        label = torch.tensor(label, dtype=torch.long)  # Convert label to tensor
        return image, label

In [52]:
class UnlabeledImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = [os.path.join(root_dir, fname) for fname in os.listdir(root_dir) if fname.endswith(('png', 'jpg', 'jpeg'))]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')  # Ensure image is in RGB mode
        if self.transform:
            image = self.transform(image)
        return image

In [42]:
class TensorLabelDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, label = self.dataset[index]
        label = torch.tensor(label, dtype=torch.long)  # Convert to tensor
        return image, label

In [26]:
# Load the unlabeled training data
train_unlabeled_dataset = datasets.ImageFolder(f'{base_dir}/Train/Unlabeled', transform=train_transforms)

# Generate the pseudo labels for the unlabeled dataset
pseudo_labeled_data = generate_pseudo_labels(model, train_unlabeled_dataset, 32)

Loaded best initial model for generation of pseudo labels on unlabeled training data
Pseudo-labeled 722 images from the unlabeled dataset.


In [43]:
pseudo_dataset = PseudoLabeledDataset(pseudo_labeled_data)

# Need to use custom dataset that sets labels to tensors instead of ints for uniformity with the unlabeled dataset
train_dataset = TensorLabelDataset(train_dataset)
# Combine labeled and pseudo-labeled datasets
augmented_dataset = ConcatDataset([train_dataset, pseudo_dataset])

In [46]:
# Extract class labels (targets) from the original ImageFolder dataset
original_targets = train_dataset.dataset.targets  # Access the underlying dataset if it's wrapped

# Convert original targets to a list (if they aren't already)
original_targets = list(original_targets)

# Extract pseudo-labels from the pseudo-labeled dataset
pseudo_targets = [label for _, label in pseudo_labeled_data]

# Combine the targets from both datasets
all_targets = original_targets + pseudo_targets

# Calculate class weights (inverse of class frequencies)
class_counts = np.bincount(all_targets)
class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)

# Assign weights to each sample in the augmented dataset
sample_weights = torch.tensor([class_weights[label] for label in all_targets], dtype=torch.float)

# Create the WeightedRandomSampler
augmented_sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

# Create the DataLoader
augmented_loader = DataLoader(augmented_dataset, batch_size=32, sampler=augmented_sampler, num_workers=4)

In [47]:
# Define data augmentations and preprocessing for the validation dataset
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Initialize the dataset
val_dataset = UnlabeledImageDataset(root_dir=f'{base_dir}/Validation/image', transform=val_transforms)

# Create DataLoader
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

In [None]:
retrained_model = train_model(
    model=model,
    train_loader=augmented_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    device=device,
    is_initial=False,
    val_loader=val_loader
)

  label = torch.tensor(label, dtype=torch.long)  # Convert label to tensor
  label = torch.tensor(label, dtype=torch.long)  # Convert label to tensor
  label = torch.tensor(label, dtype=torch.long)  # Convert label to tensor
  label = torch.tensor(label, dtype=torch.long)  # Convert label to tensor
