In [3]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.models as models

from sklearn.metrics import mean_squared_error
from tqdm import tqdm
import torch.optim as optim

from pathlib import Path

In [4]:
DATASET_MEATADATA_PATH="//kaggle/input/elapsed-thermal-wheel/metadata.csv"
DATASET_IMAGE_PATH="/kaggle/input/elapsed-thermal-wheel/images"

In [5]:
class ElapsedThermalWheelDataset(Dataset):
    def __init__(self, meatadata_file, images_folder, transform=None):
        self.metadata = pd.read_csv(meatadata_file)
        self.images_folder = images_folder
        self.transform = transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        image_path = os.path.join(self.images_folder, row['image_id'])
        elapsed_time = row['elapsed_time_seconds']

        image = Image.open(image_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return {"image": image, "elapsed_time": elapsed_time}


In [6]:
transform = transforms.Compose([
    transforms.Resize((244, 244)),
    transforms.ToTensor(),        
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [7]:
dataset = ElapsedThermalWheelDataset(meatadata_file=DATASET_MEATADATA_PATH, images_folder=DATASET_IMAGE_PATH, transform=transform)

In [8]:
class ThermalWheelTimeEstimatorVGG16(nn.Module):
    def __init__(self):
        super(ThermalWheelTimeEstimatorVGG16, self).__init__()

        self.feature_extractor = models.vgg16(pretrained=True)
        
        self.feature_extractor.classifier = nn.Sequential(
            nn.Linear(25088, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        features = self.feature_extractor(x)  
        return features

In [9]:
class ThermalWheelTimeEstimatorResnet(nn.Module):
    def __init__(self):
        super(ThermalWheelTimeEstimatorResnet, self).__init__()
        self.feature_extractor = models.resnet18(pretrained=True)
        
        num_features = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Linear(num_features, 256)
        
        self.regressor = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        features = self.feature_extractor(x)  
        time_prediction = self.regressor(features)
        return time_prediction

In [10]:
def evaluate_model(model, loader, device):
    model.eval()
    val_loss = 0.0
    criterion = nn.MSELoss()

    all_targets = []
    all_outputs = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="[Validation]"):
            images = batch["image"].to(device)
            targets = batch["elapsed_time"].to(device).float()
            
            outputs = model(images)
            loss = criterion(outputs.squeeze(), targets)
            val_loss += loss.item()

            all_targets.extend(targets.cpu().numpy())
            all_outputs.extend(outputs.squeeze().cpu().numpy())
    
    val_loss /= len(loader)
    mse = mean_squared_error(all_targets, all_outputs)
    print(f"Validation MSE: {mse:.4f}")
    return val_loss

In [11]:
def train_model(model, train_loader, val_loader, device, num_epochs=10, learning_rate=1e-4):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    model.to(device)
    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]"):
            images = batch["image"].to(device)
            targets = batch["elapsed_time"].to(device).float()
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs.squeeze(), targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}")
        
        val_loss = evaluate_model(model, val_loader, device)
        print(f"Epoch {epoch+1}, Validation Loss: {val_loss:.4f}")

In [12]:
TRAIN_SIZE = int(0.8 * len(dataset))
VAL_SIZE = len(dataset) - TRAIN_SIZE

In [13]:
train_dataset, val_dataset = random_split(dataset, [TRAIN_SIZE, VAL_SIZE])

In [14]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

In [15]:
model = ThermalWheelTimeEstimatorResnet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 172MB/s] 


In [16]:
train_model(model, train_loader, val_loader, device, num_epochs=20, learning_rate=1e-3)

Epoch 1/20 [Train]: 100%|██████████| 5/5 [00:02<00:00,  2.10it/s]


Epoch 1, Train Loss: 69073.5141


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  2.53it/s]


Validation MSE: 96591.7266
Epoch 1, Validation Loss: 70791.9941


Epoch 2/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.82it/s]


Epoch 2, Train Loss: 63768.9492


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.07it/s]


Validation MSE: 44560.7344
Epoch 2, Validation Loss: 36017.4971


Epoch 3/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.81it/s]


Epoch 3, Train Loss: 48924.9563


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.69it/s]


Validation MSE: 27120.3809
Epoch 3, Validation Loss: 19682.8892


Epoch 4/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.41it/s]


Epoch 4, Train Loss: 28053.3473


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.09it/s]


Validation MSE: 117562.5703
Epoch 4, Validation Loss: 93724.8320


Epoch 5/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.96it/s]


Epoch 5, Train Loss: 12095.5879


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.10it/s]


Validation MSE: 44829.7109
Epoch 5, Validation Loss: 30923.9019


Epoch 6/20 [Train]: 100%|██████████| 5/5 [00:00<00:00,  5.03it/s]


Epoch 6, Train Loss: 8014.7273


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.23it/s]


Validation MSE: 48264.7773
Epoch 6, Validation Loss: 29322.6012


Epoch 7/20 [Train]: 100%|██████████| 5/5 [00:00<00:00,  5.06it/s]


Epoch 7, Train Loss: 8462.6335


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.91it/s]


Validation MSE: 28340.2695
Epoch 7, Validation Loss: 24893.8740


Epoch 8/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.93it/s]


Epoch 8, Train Loss: 3550.7928


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


Validation MSE: 2440709.0000
Epoch 8, Validation Loss: 2441085.7500


Epoch 9/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.82it/s]


Epoch 9, Train Loss: 1837.3419


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.05it/s]


Validation MSE: 4056951.7500
Epoch 9, Validation Loss: 4252063.2500


Epoch 10/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.95it/s]


Epoch 10, Train Loss: 3707.2002


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.20it/s]


Validation MSE: 64269.6094
Epoch 10, Validation Loss: 57402.9492


Epoch 11/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.99it/s]


Epoch 11, Train Loss: 3649.1679


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.14it/s]


Validation MSE: 15811.6973
Epoch 11, Validation Loss: 10845.2675


Epoch 12/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.84it/s]


Epoch 12, Train Loss: 3179.9875


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.10it/s]


Validation MSE: 16863.1406
Epoch 12, Validation Loss: 10804.5111


Epoch 13/20 [Train]: 100%|██████████| 5/5 [00:00<00:00,  5.06it/s]


Epoch 13, Train Loss: 2198.8950


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.27it/s]


Validation MSE: 19581.0547
Epoch 13, Validation Loss: 13870.5691


Epoch 14/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


Epoch 14, Train Loss: 1642.5057


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.97it/s]


Validation MSE: 10187.5439
Epoch 14, Validation Loss: 6019.6178


Epoch 15/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.86it/s]


Epoch 15, Train Loss: 1757.0475


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.04it/s]


Validation MSE: 8540.8369
Epoch 15, Validation Loss: 6255.7522


Epoch 16/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]


Epoch 16, Train Loss: 1555.5109


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.84it/s]


Validation MSE: 10696.5527
Epoch 16, Validation Loss: 6637.7277


Epoch 17/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.94it/s]


Epoch 17, Train Loss: 1627.5022


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.94it/s]


Validation MSE: 8961.1631
Epoch 17, Validation Loss: 5527.7658


Epoch 18/20 [Train]: 100%|██████████| 5/5 [00:00<00:00,  5.04it/s]


Epoch 18, Train Loss: 1863.2157


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.19it/s]


Validation MSE: 7769.9609
Epoch 18, Validation Loss: 5019.6016


Epoch 19/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.92it/s]


Epoch 19, Train Loss: 1392.8962


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  3.82it/s]


Validation MSE: 9258.2812
Epoch 19, Validation Loss: 5572.9091


Epoch 20/20 [Train]: 100%|██████████| 5/5 [00:01<00:00,  4.96it/s]


Epoch 20, Train Loss: 1915.2431


[Validation]: 100%|██████████| 2/2 [00:00<00:00,  4.13it/s]

Validation MSE: 13871.3613
Epoch 20, Validation Loss: 8102.9796





In [17]:
model.eval()
for image_path in Path("/kaggle/input/test-time-thermal-wheel").glob("*.png"):
    image = Image.open(image_path).convert("RGB")
    tr_image = transform(image)

    with torch.inference_mode():
        print(f"name = {image_path} prediction = {model(torch.stack([tr_image]).to(device).float()).squeeze()}")

name = /kaggle/input/test-time-thermal-wheel/img_20241204_215825.png prediction = 16.323925018310547
name = /kaggle/input/test-time-thermal-wheel/img_20241204_215745.png prediction = 62.69554138183594
name = /kaggle/input/test-time-thermal-wheel/img_20241204_215814.png prediction = 19.826839447021484
name = /kaggle/input/test-time-thermal-wheel/img_20241205_135931.png prediction = 86.96996307373047
name = /kaggle/input/test-time-thermal-wheel/img_20241205_140023.png prediction = 47.446414947509766
name = /kaggle/input/test-time-thermal-wheel/img_20241204_215905.png prediction = 33.95003128051758
name = /kaggle/input/test-time-thermal-wheel/img_20241204_220333.png prediction = 29.63716697692871
name = /kaggle/input/test-time-thermal-wheel/img_20241204_220313.png prediction = 79.51371765136719
name = /kaggle/input/test-time-thermal-wheel/img_20241204_215835.png prediction = 40.6210823059082
name = /kaggle/input/test-time-thermal-wheel/img_20241204_215924.png prediction = 52.5791549682617