In [None]:
import kagglehub

# Fetch the FLIR thermal images dataset from Kaggle. `path` is the local
# dataset directory that the following cells list and read images from.
# NOTE(review): kagglehub is imported here but only installed by the later
# pip cell; on a fresh environment run that cell first.
DATASET_HANDLE = "deepnewbie/flir-thermal-images-dataset"
path = kagglehub.dataset_download(DATASET_HANDLE)
print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/deepnewbie/flir-thermal-images-dataset?dataset_version_number=1...


 27%|██▋       | 4.16G/15.3G [13:58<33:51, 5.87MB/s]  

In [None]:
# List the top-level files and folders of the downloaded dataset, so the
# sub-directory names used in the next cell can be checked by eye.
import os

# os.listdir returns entries in arbitrary order; sort for reproducible output.
files = sorted(os.listdir(path))
print("Files and folders in dataset root:")
for f in files:
    print(f)

In [None]:
# Example: Set RGB and thermal image directories based on dataset structure.
# NOTE(review): these sub-directory names are assumed, not verified; compare
# them with the dataset listing and adjust if the layout differs.
rgb_dir = os.path.join(path, "FLIR_ADAS_v2", "images_rgb_8_bit")
thermal_dir = os.path.join(path, "FLIR_ADAS_v2", "images_thermal_8_bit")

print("RGB directory:", rgb_dir)
print("Thermal directory:", thermal_dir)

# Warn early (without stopping Run All): the dataset class calls os.listdir on
# these directories and would otherwise only fail once training starts.
for _label, _dir in (("RGB", rgb_dir), ("Thermal", thermal_dir)):
    if not os.path.isdir(_dir):
        print(f"WARNING: {_label} directory does not exist: {_dir}")

In [None]:
# Install required libraries if not already installed
!pip install torch torchvision pillow kagglehub

In [None]:
# Define the custom dataset and U-Net model for RGB to thermal translation
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as T

class FLIRRGB2ThermalDataset(Dataset):
    """Paired RGB -> thermal image dataset read from two flat directories.

    Image filenames in each directory are sorted and paired by position: the
    i-th RGB file is matched with the i-th thermal file.
    NOTE(review): this assumes both folders contain the same frames in the same
    sorted order; confirm for the FLIR layout in use.

    Args:
        rgb_dir: Directory containing the RGB images.
        thermal_dir: Directory containing the thermal images.
        img_size: Both images are resized to (img_size, img_size).
        extensions: Filename suffixes (matched case-insensitively) that count
            as images.

    Each item is ``(rgb, thermal)``: the RGB image converted to 'RGB' and the
    thermal image converted to single-channel 'L', both resized and turned
    into tensors with ``T.ToTensor``.
    """

    def __init__(self, rgb_dir, thermal_dir, img_size=256, extensions=('.jpg', '.png')):
        self.rgb_dir = rgb_dir
        self.thermal_dir = thermal_dir
        self.img_size = img_size
        self.rgb_files = self._list_images(rgb_dir, extensions)
        self.thermal_files = self._list_images(thermal_dir, extensions)
        if len(self.rgb_files) != len(self.thermal_files):
            # Positional pairing cannot be trusted when the counts differ; only
            # the first min(n_rgb, n_thermal) pairs are exposed via __len__.
            import warnings
            warnings.warn(
                f"FLIRRGB2ThermalDataset: {len(self.rgb_files)} RGB vs "
                f"{len(self.thermal_files)} thermal images; using the first "
                f"{min(len(self.rgb_files), len(self.thermal_files))} positional pairs."
            )
        self.transform_rgb = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
        ])
        self.transform_thermal = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
        ])

    @staticmethod
    def _list_images(directory, extensions):
        """Return the sorted filenames in ``directory`` ending in one of ``extensions``."""
        suffixes = tuple(ext.lower() for ext in extensions)
        return sorted(f for f in os.listdir(directory) if f.lower().endswith(suffixes))

    def __len__(self):
        # Bounded by the shorter list so __getitem__ never indexes past its end.
        return min(len(self.rgb_files), len(self.thermal_files))

    def __getitem__(self, idx):
        rgb_path = os.path.join(self.rgb_dir, self.rgb_files[idx])
        thermal_path = os.path.join(self.thermal_dir, self.thermal_files[idx])
        # convert() returns a loaded copy, so the file handles can be closed
        # immediately instead of waiting for garbage collection.
        with Image.open(rgb_path) as img:
            rgb = img.convert('RGB')
        with Image.open(thermal_path) as img:
            thermal = img.convert('L')
        return self.transform_rgb(rgb), self.transform_thermal(thermal)

class UNet(nn.Module):
    """U-Net encoder/decoder producing a per-pixel output map.

    Args:
        in_channels: Channels of the input tensor (3 for RGB input).
        out_channels: Channels of the output tensor (1 for a thermal map).
        features: Channel width of each encoder stage. The bottleneck uses twice
            the last width, and the decoder mirrors the encoder.

    An input of shape (N, in_channels, H, W) yields (N, out_channels, H, W).
    Spatial sizes that are not divisible by 2**len(features) are handled by
    interpolating each upsampled map to its skip connection's size. The final
    1x1 convolution has no activation, so outputs are unbounded.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # Tuple default avoids a shared mutable default; accept any sequence.
        features = list(features)
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        for feature in features:
            self.downs.append(self.conv_block(in_channels, feature))
            in_channels = feature
        # Decoder stages are stored as consecutive pairs in self.ups:
        # (ConvTranspose2d doubling H/W and halving channels, conv_block over
        # the [skip, upsampled] concatenation).
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.ups.append(self.conv_block(feature*2, feature))
        self.bottleneck = self.conv_block(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
        # Created once instead of per forward call. MaxPool2d has no parameters
        # or buffers, so state_dict keys (and existing checkpoints) are unchanged.
        self.pool = nn.MaxPool2d(2)

    def conv_block(self, in_c, out_c):
        """Two 3x3 conv -> BatchNorm -> ReLU layers; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        # Deepest encoder output pairs with the first decoder stage.
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip = skip_connections[idx // 2]
            if x.shape[2:] != skip.shape[2:]:
                # Odd sizes lose a row/column to floor division in pooling;
                # resize the upsampled map back to the skip's spatial size.
                x = nn.functional.interpolate(
                    x, size=skip.shape[2:], mode='bilinear', align_corners=False
                )
            x = torch.cat((skip, x), dim=1)
            x = self.ups[idx + 1](x)
        return self.final_conv(x)

In [None]:
# Training loop for the U-Net model
def train_model(
    rgb_dir, thermal_dir, epochs=10, batch_size=4, lr=1e-4, img_size=256, save_path='unet_rgb2thermal.pth',
    num_workers=2,
):
    """Train a UNet to map RGB images to thermal images using L1 loss and Adam.

    After each epoch the mean training loss is printed. Whenever it is lower
    than the best seen so far, the model's state_dict is written to ``save_path``.

    Args:
        rgb_dir: Directory of RGB images, passed to FLIRRGB2ThermalDataset.
        thermal_dir: Directory of thermal images, passed to FLIRRGB2ThermalDataset.
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size.
        lr: Adam learning rate.
        img_size: Square size that both images are resized to.
        save_path: File the best state_dict is saved to.
        num_workers: DataLoader worker processes. Set to 0 if worker processes
            fail to start from the notebook.

    Raises:
        ValueError: if the dataset contains no image pairs.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = FLIRRGB2ThermalDataset(rgb_dir, thermal_dir, img_size)
    if len(dataset) == 0:
        # Otherwise the per-epoch average below would divide by zero.
        raise ValueError(f"No RGB/thermal image pairs found in {rgb_dir!r} and {thermal_dir!r}")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    model = UNet().to(device)
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_loss = float('inf')
    saved = False
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for rgb, thermal in loader:
            rgb, thermal = rgb.to(device), thermal.to(device)
            pred = model(rgb)
            loss = criterion(pred, thermal)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch doesn't skew the epoch mean.
            running_loss += loss.item() * rgb.size(0)
        epoch_loss = running_loss / len(dataset)
        print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}")
        # A NaN loss never compares lower than best_loss, so it is never saved.
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), save_path)
            saved = True
    if saved:
        print("Training complete. Best model saved.")
    else:
        # e.g. epochs == 0 or every epoch loss was NaN.
        print(f"Training complete. No model was saved to {save_path}.")

# Example usage:
# train_model(rgb_dir, thermal_dir, epochs=10)

In [None]:
# Inference: Generate synthetic thermal image from RGB input
def infer(rgb_img_path, model_path, img_size=256):
    """Generate a synthetic thermal image from a single RGB image.

    Loads UNet weights from ``model_path`` (a state_dict such as the one
    written by train_model), resizes the input to (img_size, img_size), runs
    the model and converts its single-channel output to an 8-bit image.

    Args:
        rgb_img_path: Path to the RGB input image.
        model_path: Path to the saved UNet state_dict.
        img_size: Square size the input is resized to. Presumably this should
            match the size used for training.

    Returns:
        A grayscale PIL image of size (img_size, img_size).
    """
    import numpy as np
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNet().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    transform = T.Compose([
        T.Resize((img_size, img_size)),
        T.ToTensor(),
    ])
    # Close the file handle as soon as the converted copy exists.
    with Image.open(rgb_img_path) as img:
        inp = transform(img.convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(inp)
    # out is (1, 1, H, W). Index explicitly rather than squeeze(), which would
    # also drop H or W if either were 1.
    out_img = out[0, 0].cpu().numpy()
    # The model has no output activation; map the presumed [0, 1] range to
    # 0-255 and clip anything outside it.
    out_img = (out_img * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(out_img)

# Example usage:
# result = infer('path/to/sample_rgb.jpg', 'unet_rgb2thermal.pth')
# result.save('synthetic_thermal.png')