In [None]:
import kagglehub

# Fetch the latest published version of the FLIR thermal dataset via kagglehub.
# The value returned by dataset_download is kept in `path`; later cells build
# every dataset folder relative to it.
dataset_handle = "deepnewbie/flir-thermal-images-dataset"
path = kagglehub.dataset_download(dataset_handle)

print("Path to dataset files:", path)

In [None]:
# List files and folders in the dataset directory
import os

# Top-level entries of the dataset directory. os.listdir returns names in
# arbitrary order, so sort them to keep the printed listing stable across runs.
files = sorted(os.listdir(path))
print("Files and folders in dataset root:")
for entry in files:
    print(entry)

# Print all folders and subfolders recursively.
# The per-folder file list yielded by os.walk is bound to a throwaway name so it
# no longer overwrites the root listing stored in `files`.
print("\nAll folders and subfolders:")
for root, dirs, _filenames in os.walk(path):
    print(f"Folder: {root}")
    for d in dirs:
        # The name is now interpolated into the f-string; the previous form passed
        # it as a second print() argument, which produced a doubled space.
        print(f"  Subfolder: {d}")

In [None]:
# Set RGB and thermal image directories for train and val splits.
#
# Thermal targets are read from the 8-bit folder so that training and evaluation
# agree: the pairing helper and the dataset class in this notebook only accept
# .jpg/.jpeg/.png files and decode thermal frames with PIL's 8-bit 'L' mode, and
# the validation cell scores predictions against "thermal_8_bit".
# NOTE(review): assumes train/ also contains a "thermal_8_bit" folder — confirm
# against the folder listing printed by the previous cell.
dataset_root = os.path.join(path, "FLIR_ADAS_1_3")
thermal_subdir = "thermal_8_bit"

train_rgb_dir = os.path.join(dataset_root, "train", "RGB")
train_thermal_dir = os.path.join(dataset_root, "train", thermal_subdir)
val_rgb_dir = os.path.join(dataset_root, "val", "RGB")
val_thermal_dir = os.path.join(dataset_root, "val", thermal_subdir)

print("Train RGB directory:", train_rgb_dir)
print("Train Thermal directory:", train_thermal_dir)
print("Val RGB directory:", val_rgb_dir)
print("Val Thermal directory:", val_thermal_dir)

# Pair RGB and thermal images whose filenames share the same stem (extension ignored)
def get_matched_files_by_prefix(rgb_dir, thermal_dir, exclude_prefix=None):
    """Return the RGB/thermal image files that exist in both directories.

    Only files ending in .jpg, .jpeg or .png (case-insensitive) are considered.
    Two files are paired when their names are equal once the extension is removed.

    Args:
        rgb_dir: Directory holding the RGB images.
        thermal_dir: Directory holding the thermal images.
        exclude_prefix: Optional stem to leave out of the result.

    Returns:
        A tuple ``(matched_rgb, matched_thermal, common_prefixes)`` of three lists
        of equal length, ordered by sorted stem. The first two hold filenames
        (not full paths); the third holds the shared stems.
    """
    image_exts = ('.jpg', '.jpeg', '.png')

    def stem_to_name(folder):
        # Map "name without extension" -> "name as stored on disk".
        mapping = {}
        for name in os.listdir(folder):
            if name.lower().endswith(image_exts):
                mapping[os.path.splitext(name)[0]] = name
        return mapping

    rgb_by_stem = stem_to_name(rgb_dir)
    thermal_by_stem = stem_to_name(thermal_dir)

    shared_stems = sorted(rgb_by_stem.keys() & thermal_by_stem.keys())
    if exclude_prefix is not None:
        shared_stems = [stem for stem in shared_stems if stem != exclude_prefix]

    rgb_names = []
    thermal_names = []
    for stem in shared_stems:
        rgb_names.append(rgb_by_stem[stem])
        thermal_names.append(thermal_by_stem[stem])
    return rgb_names, thermal_names, shared_stems

# Build the train/val pairings; the train split leaves out the FLIR_00001 pair
matched_rgb, matched_thermal, matched_prefixes = get_matched_files_by_prefix(
    train_rgb_dir,
    train_thermal_dir,
    exclude_prefix="FLIR_00001",
)
val_matched_rgb, val_matched_thermal, val_matched_prefixes = get_matched_files_by_prefix(
    val_rgb_dir,
    val_thermal_dir,
)

# Report how many files ended up in each of the four lists.
for label, names in (
    ("train RGB", matched_rgb),
    ("train thermal", matched_thermal),
    ("val RGB", val_matched_rgb),
    ("val thermal", val_matched_thermal),
):
    print(f"Number of matched {label} images: {len(names)}")

In [7]:
# Define the custom dataset and U-Net model for RGB to thermal translation
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as T

class FLIRRGB2ThermalDataset(Dataset):
    """Paired RGB -> thermal image dataset.

    Item ``i`` is the tuple ``(rgb, thermal)`` built from ``rgb_files[i]`` and
    ``thermal_files[i]``. The RGB file is decoded in mode 'RGB' and the thermal
    file in 8-bit grayscale mode 'L'; both are then resized to
    ``(img_size, img_size)`` and converted with ``ToTensor()``.

    Args:
        rgb_dir: Directory holding the RGB images.
        thermal_dir: Directory holding the thermal images.
        img_size: Side length of the square output images.
        matched_rgb: Optional list of RGB filenames (relative to ``rgb_dir``).
        matched_thermal: Optional list of thermal filenames (relative to
            ``thermal_dir``), index-aligned with ``matched_rgb``.
            When either list is missing, the two directories are scanned and
            files are paired by filename stem instead.

    Raises:
        ValueError: If the two supplied lists differ in length, or if no pair
            is available.
    """

    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, rgb_dir, thermal_dir, img_size=256, matched_rgb=None, matched_thermal=None):
        self.rgb_dir = rgb_dir
        self.thermal_dir = thermal_dir
        self.img_size = img_size
        # Use matched file lists if provided, else match by prefix
        if matched_rgb is not None and matched_thermal is not None:
            # The lists are consumed index-by-index, so a length mismatch would
            # silently mis-pair or raise IndexError deep inside a DataLoader.
            if len(matched_rgb) != len(matched_thermal):
                raise ValueError(
                    f"matched_rgb ({len(matched_rgb)}) and matched_thermal "
                    f"({len(matched_thermal)}) must have the same length."
                )
            # Copy so later mutation of the caller's lists cannot change the dataset.
            self.rgb_files = list(matched_rgb)
            self.thermal_files = list(matched_thermal)
        else:
            self.rgb_files, self.thermal_files = self._pair_by_prefix(rgb_dir, thermal_dir)
        print(f"Paired {len(self.rgb_files)} RGB and thermal images.")
        if len(self.rgb_files) == 0:
            raise ValueError("No matching RGB and thermal image pairs found!")
        # Both modalities currently share the same preprocessing; they are kept as
        # separate attributes so either one can be customised independently.
        self.transform_rgb = self._make_transform(img_size)
        self.transform_thermal = self._make_transform(img_size)

    @staticmethod
    def _make_transform(img_size):
        """Resize to a square of side ``img_size`` and convert to a tensor."""
        return T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
        ])

    @classmethod
    def _pair_by_prefix(cls, rgb_dir, thermal_dir):
        """Pair image files from both directories by filename stem, sorted by stem."""
        def stem_to_name(folder):
            return {
                os.path.splitext(name)[0]: name
                for name in os.listdir(folder)
                if name.lower().endswith(cls._IMAGE_EXTENSIONS)
            }

        rgb_prefix = stem_to_name(rgb_dir)
        thermal_prefix = stem_to_name(thermal_dir)
        common_prefixes = sorted(rgb_prefix.keys() & thermal_prefix.keys())
        return (
            [rgb_prefix[p] for p in common_prefixes],
            [thermal_prefix[p] for p in common_prefixes],
        )

    def __len__(self):
        return len(self.rgb_files)

    def __getitem__(self, idx):
        rgb_path = os.path.join(self.rgb_dir, self.rgb_files[idx])
        thermal_path = os.path.join(self.thermal_dir, self.thermal_files[idx])
        # convert() returns a new, fully loaded image, so the file handle can be
        # released as soon as the with-block exits.
        with Image.open(rgb_path) as rgb_file:
            rgb = rgb_file.convert('RGB')
        with Image.open(thermal_path) as thermal_file:
            thermal = thermal_file.convert('L')
        rgb = self.transform_rgb(rgb)
        thermal = self.transform_thermal(thermal)
        return rgb, thermal

class UNet(nn.Module):
    """U-Net style encoder/decoder built from double-convolution blocks.

    The encoder applies one conv block per entry of ``features``, each followed
    by 2x2 max pooling. A bottleneck block doubles the last feature count. The
    decoder mirrors the encoder: a stride-2 transposed convolution, concatenation
    with the matching encoder output (skip connection), then a conv block. A
    final 1x1 convolution maps to ``out_channels``. No output activation is
    applied.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        features: Channel widths of the encoder stages, shallowest first.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # Immutable default + local copy: the argument is never shared or mutated.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width.")
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        for feature in features:
            self.downs.append(self.conv_block(in_channels, feature))
            in_channels = feature
        # `ups` alternates [upsample, conv_block] per decoder stage.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.ups.append(self.conv_block(feature*2, feature))
        self.bottleneck = self.conv_block(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
        # Built once instead of on every forward pass. MaxPool2d has no parameters
        # or buffers, so state_dict keys are unchanged.
        self.pool = nn.MaxPool2d(2)

    def conv_block(self, in_c, out_c):
        """Two 3x3 conv -> BatchNorm -> ReLU layers; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the network on a batch ``x`` indexed as (N, C, H, W)."""
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        # Decoder consumes the skips deepest-first.
        for stage, skip in enumerate(reversed(skip_connections)):
            x = self.ups[2 * stage](x)
            # Pooling floors odd sizes, so the upsampled map can be smaller than
            # the skip. Only the spatial dims can differ here. Bilinear
            # interpolation is used, the same default mode as the torchvision
            # resize this replaces, without depending on an image-transform API
            # inside the model.
            if x.shape[2:] != skip.shape[2:]:
                x = nn.functional.interpolate(
                    x, size=skip.shape[2:], mode='bilinear', align_corners=False
                )
            x = torch.cat((skip, x), dim=1)
            x = self.ups[2 * stage + 1](x)
        return self.final_conv(x)

In [8]:
# Training loop for the U-Net model with diagnostics and quick test mode
def train_model(
    rgb_dir, thermal_dir, matched_rgb, matched_thermal, epochs=10, batch_size=8, lr=5e-4, img_size=256, save_path='unet_rgb2thermal.pth', max_samples=100
):
    """Train a freshly initialised UNet to predict thermal images from RGB images.

    Uses L1 loss and the Adam optimizer. After each epoch the sample-weighted
    mean training loss is computed; whenever it improves on the best value so
    far, the model's ``state_dict`` is written to ``save_path``.

    Args:
        rgb_dir: Directory holding the RGB images.
        thermal_dir: Directory holding the thermal images.
        matched_rgb: RGB filenames, index-aligned with ``matched_thermal``.
        matched_thermal: Thermal filenames.
        epochs: Number of passes over the (possibly truncated) dataset.
        batch_size: DataLoader batch size.
        lr: Adam learning rate.
        img_size: Side length the images are resized to.
        save_path: File the best checkpoint is written to.
        max_samples: If not None, train only on the first ``max_samples`` pairs
            (in list order). Must be positive.

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: If ``max_samples`` is given but not positive.
    """
    # An empty subset would only surface as a ZeroDivisionError at the end of the
    # first epoch, so reject it up front with a clear message.
    if max_samples is not None and max_samples <= 0:
        raise ValueError(f"max_samples must be a positive integer or None, got {max_samples!r}.")

    import time
    from torch.utils.data import Subset
    from tqdm import tqdm

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = FLIRRGB2ThermalDataset(rgb_dir, thermal_dir, img_size, matched_rgb=matched_rgb, matched_thermal=matched_thermal)
    if max_samples is not None:
        dataset = Subset(dataset, list(range(min(max_samples, len(dataset)))))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    model = UNet().to(device)
    criterion = nn.L1Loss()  # Try nn.MSELoss() for even smoother results
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_loss = float('inf')
    checkpoint_saved = False
    train_losses = []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        start_time = time.time()
        progress = tqdm(loader, desc=f"Epoch {epoch+1}")
        for rgb, thermal in progress:
            rgb, thermal = rgb.to(device), thermal.to(device)
            optimizer.zero_grad()
            pred = model(rgb)
            loss = criterion(pred, thermal)
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()  # read once; each .item() call syncs with the device
            # Weight by batch size so a short final batch does not skew the epoch mean.
            running_loss += batch_loss * rgb.size(0)
            # Show the batch loss on the progress bar itself; separate print() calls
            # break the bar and flood the cell output.
            progress.set_postfix(batch_loss=f"{batch_loss:.4f}")
        epoch_loss = running_loss / len(dataset)
        train_losses.append(epoch_loss)
        elapsed = time.time() - start_time
        print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f} - Time: {elapsed:.1f}s")
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), save_path)
            checkpoint_saved = True
    if checkpoint_saved:
        print("Training complete. Best model saved.")
    else:
        # Reached when no epoch ran or no epoch loss compared below infinity (e.g. NaN).
        print("Training complete. No checkpoint was saved.")
    return train_losses

# Example usage:
#train_losses = train_model(train_rgb_dir, train_thermal_dir, matched_rgb, matched_thermal, epochs=10, batch_size=8, lr=5e-4, max_samples=100)

In [None]:
# Visualize matched RGB and thermal image pairs in batches of 100
import matplotlib.pyplot as plt
from PIL import Image

def show_matched_images(rgb_dir, thermal_dir, matched_rgb, matched_thermal, batch_size=100):
    """Display matched RGB/thermal pairs side by side, one figure per batch.

    Each figure has one row per pair: the RGB image on the left and the thermal
    image (gray colormap) on the right, each titled with its filename.

    Args:
        rgb_dir: Directory holding the RGB images.
        thermal_dir: Directory holding the thermal images.
        matched_rgb: RGB filenames, index-aligned with ``matched_thermal``.
        matched_thermal: Thermal filenames.
        batch_size: Maximum number of pairs drawn in a single figure.

    Raises:
        ValueError: If the two filename lists differ in length.
    """
    if len(matched_rgb) != len(matched_thermal):
        raise ValueError(
            f"matched_rgb ({len(matched_rgb)}) and matched_thermal "
            f"({len(matched_thermal)}) must have the same length."
        )
    total = len(matched_rgb)
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        rows = end - start
        print(f"Showing images {start} to {end-1}")
        # squeeze=False always yields a 2-D axes array, so a single-row batch
        # needs no special casing.
        fig, axes = plt.subplots(rows, 2, figsize=(6, rows * 2), squeeze=False)
        for (rgb_ax, thermal_ax), idx in zip(axes, range(start, end)):
            # imshow copies the pixel data, so the files can be closed right away.
            with Image.open(os.path.join(rgb_dir, matched_rgb[idx])) as rgb_img:
                rgb_ax.imshow(rgb_img)
            rgb_ax.set_title(f"RGB: {matched_rgb[idx]}")
            rgb_ax.axis('off')
            with Image.open(os.path.join(thermal_dir, matched_thermal[idx])) as thermal_img:
                thermal_ax.imshow(thermal_img, cmap='gray')
            thermal_ax.set_title(f"Thermal: {matched_thermal[idx]}")
            thermal_ax.axis('off')
        plt.tight_layout()
        plt.show()
        # Release the figure so looping over many batches does not accumulate memory.
        plt.close(fig)

# Example: Show first 100 matched pairs
#show_matched_images(train_rgb_dir, train_thermal_dir, matched_rgb, matched_thermal, batch_size=100)

In [None]:
# Shorter run: train a new model on at most the first 1000 matched training pairs
train_losses0 = train_model(
    rgb_dir=train_rgb_dir,
    thermal_dir=train_thermal_dir,
    matched_rgb=matched_rgb,
    matched_thermal=matched_thermal,
    epochs=10,
    batch_size=8,
    lr=5e-4,
    max_samples=1000,
)


In [None]:
# Full run: train a new model on every matched training pair (no sample cap)
train_losses = train_model(
    rgb_dir=train_rgb_dir,
    thermal_dir=train_thermal_dir,
    matched_rgb=matched_rgb,
    matched_thermal=matched_thermal,
    epochs=10,
    batch_size=8,
    lr=5e-4,
    max_samples=None,
)

## Validation, Metrics, and Visualization

Let's evaluate the model on a validation set, compute metrics, and visualize predictions vs. ground truth.

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Evaluate the saved checkpoint on the validation split, pairing val/RGB with
# val/thermal_8_bit.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)
checkpoint_state = torch.load('unet_rgb2thermal.pth', map_location=device)
model.load_state_dict(checkpoint_state)
model.eval()

val_split_dir = os.path.join(path, "FLIR_ADAS_1_3", "val")
val_rgb_dir = os.path.join(val_split_dir, "RGB")
val_thermal_dir = os.path.join(val_split_dir, "thermal_8_bit")

val_matched_rgb, val_matched_thermal, _ = get_matched_files_by_prefix(val_rgb_dir, val_thermal_dir)

val_dataset = FLIRRGB2ThermalDataset(
    val_rgb_dir,
    val_thermal_dir,
    img_size=256,
    matched_rgb=val_matched_rgb,
    matched_thermal=val_matched_thermal,
)
# One image per batch, in fixed order, so each metric entry belongs to one image.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Per-image metrics, plus the raw arrays kept for the cells further down.
mae_list, mse_list = [], []
outputs, gts = [], []

with torch.no_grad():
    for rgb, thermal in val_loader:
        rgb, thermal = rgb.to(device), thermal.to(device)
        pred = model(rgb)
        # Drop the size-1 dimensions and move to NumPy on the CPU.
        pred_np = pred.squeeze().cpu().numpy()
        gt_np = thermal.squeeze().cpu().numpy()
        outputs.append(pred_np)
        gts.append(gt_np)
        flat_gt, flat_pred = gt_np.ravel(), pred_np.ravel()
        mae_list.append(mean_absolute_error(flat_gt, flat_pred))
        mse_list.append(mean_squared_error(flat_gt, flat_pred))

# Average of the per-image scores.
mean_mae = sum(mae_list) / len(mae_list)
mean_mse = sum(mse_list) / len(mse_list)
print(f"Validation MAE: {mean_mae:.4f}")
print(f"Validation MSE: {mean_mse:.4f}")


In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Imported here because this cell is the first to call numpy directly.
import numpy as np

# Binarize outputs and ground truth for classification metrics
# NOTE(review): the model is trained as a regressor; thresholding intensities at
# 0.5 is a coarse proxy — confirm these scores are meaningful for the task.
threshold = 0.5

# Build one flat 0/1 array per side with a single concatenate. Extending Python
# lists pixel-by-pixel creates one Python object per pixel, which is very slow
# and memory-hungry for a full validation set.
all_preds = np.concatenate([(pred_np > threshold).ravel() for pred_np in outputs]).astype(int)
all_gts = np.concatenate([(gt_np > threshold).ravel() for gt_np in gts]).astype(int)

accuracy = accuracy_score(all_gts, all_preds)
# zero_division=0: report 0 instead of warning when a score's denominator is empty.
precision = precision_score(all_gts, all_preds, zero_division=0)
recall = recall_score(all_gts, all_preds, zero_division=0)
f1 = f1_score(all_gts, all_preds, zero_division=0)

print(f"Validation Accuracy: {accuracy:.4f}")
print(f"Validation Precision: {precision:.4f}")
print(f"Validation Recall: {recall:.4f}")
print(f"Validation F1 Score: {f1:.4f}")

In [None]:
# Visualize a few predictions vs. ground truth
max_show = 3
# Never index past the available results: the validation set may hold fewer
# images than requested.
num_show = min(max_show, len(gts))

# squeeze=False keeps `axes` two-dimensional even when only one row is drawn.
fig, axes = plt.subplots(num_show, 2, figsize=(10, num_show * 3), squeeze=False)
for i, (gt_ax, pred_ax) in enumerate(axes):
    gt_ax.imshow(gts[i], cmap='gray')
    gt_ax.set_title('Ground Truth')
    gt_ax.axis('off')
    pred_ax.imshow(outputs[i], cmap='gray')
    pred_ax.set_title('Predicted')
    pred_ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Training loss per epoch, as returned by the full training run
fig, ax = plt.subplots()
ax.plot(train_losses)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Curve')
plt.show()

In [None]:
# Write the weights of the current `model` object to a separate "final" checkpoint file
final_checkpoint_path = 'unet_rgb2thermal_final.pth'
torch.save(model.state_dict(), final_checkpoint_path)
print(f"Final model saved as {final_checkpoint_path}")