## Environment Setup

In [None]:
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd
from pathlib import Path
import torch.nn as nn
import torch.nn.functional as F


ModuleNotFoundError: No module named 'torch'

## DataLoader

### Load Image

In [None]:
class MyDataset(Dataset):
    """Dataset pairing each dish's colour image with its depth map and label.

    Expected layout (all names configurable through the constructor)::

        root/<csv_name>                               # columns "ID" and "Value"
        root/<data_dir>/<color_dir>/dish_*/<rgb_name>
        root/<data_dir>/<depth_dir>/dish_*/<depth_name>

    Args:
        root: Directory containing the CSV file and the split folders.
        if_train: When truthy, labels are read from the CSV and every item is
            ``(rgb, depth, label)``; otherwise every item is ``(rgb, depth)``.
        data_dir: Split folder below ``root``.
        color_dir: Folder below the split that holds the per-dish RGB folders.
        depth_dir: Folder below the split that holds the per-dish depth folders.
        rgb_name: File name of the RGB image inside each dish folder.
        depth_name: File name of the depth image inside each dish folder.
        csv_name: Label file below ``root``; read only when ``if_train``.
        transform: Optional callable applied to each PIL image (RGB and depth
            separately) instead of the built-in tensor conversion. Any
            non-callable value, including the default ``False``, selects the
            built-in conversion.

    Raises:
        RuntimeError: If no ``dish_*/<rgb_name>`` file exists in the colour
            directory.
    """

    def __init__(self,
                 root,
                 if_train=True,
                 data_dir="train",
                 color_dir="color",
                 depth_dir="depth_raw",
                 rgb_name="rgb.png",
                 depth_name="depth_raw.png",
                 csv_name="nutrition5k_train.csv",
                 transform=False,
                 ):
        self.root = Path(root)
        self.data = self.root / data_dir
        self.color_dir = self.data / color_dir
        self.depth_dir = self.data / depth_dir
        self.rgb_name = rgb_name
        self.depth_name = depth_name
        self.if_train = if_train
        self.transform = transform

        # Dish id -> label. Stays empty in test mode, where no CSV is read.
        self.id2cal = {}
        if self.if_train:
            df = pd.read_csv(self.root / csv_name)
            self.id2cal = dict(
                zip(df["ID"].astype(str), df["Value"].astype(float).tolist())
            )

        rgb_paths = sorted(
            p / rgb_name
            for p in self.color_dir.glob("dish_*")
            if (p / rgb_name).exists()
        )
        if not rgb_paths:
            raise RuntimeError(f"Found 0 images in {self.color_dir}.")

        # Each sample is (rgb_path, depth_path, label); label is None in test mode.
        self.samples = []
        for rgb_path in rgb_paths:
            dish_id = rgb_path.parent.name
            label = None
            if self.if_train:
                if dish_id not in self.id2cal:
                    # An unlabelled dish cannot be trained on: warn and skip it
                    # instead of failing later with a KeyError.
                    print(f"Warning: {dish_id} not found in CSV.")
                    continue
                label = self.id2cal[dish_id]
            depth_path = self.depth_dir / dish_id / self.depth_name
            self.samples.append((rgb_path, depth_path, label))

    def __len__(self):
        """Return the number of usable samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(rgb, depth, label)`` when ``if_train`` is truthy, else
            ``(rgb, depth)``. With the built-in conversion ``rgb`` is a
            ``float32`` tensor of shape ``(3, H, W)`` and ``depth`` one of
            shape ``(1, H, W)``, both scaled to ``[0, 1]``; ``label`` is a
            scalar ``float32`` tensor.
        """
        rgb_path, depth_path, cal = self.samples[idx]

        with Image.open(rgb_path) as img:
            rgb = img.convert("RGB")

        if depth_path is not None and depth_path.exists():
            # NOTE(review): "L" is 8-bit; if the depth PNGs are 16-bit this
            # conversion loses precision — confirm against the data.
            with Image.open(depth_path) as img:
                depth = img.convert("L")
        else:
            # A missing depth map becomes an all-zero image of matching size,
            # because DataLoader's default collate cannot batch None.
            depth = Image.new("L", rgb.size)

        rgb = self._prepare(rgb)
        depth = self._prepare(depth)

        if self.if_train:
            return rgb, depth, torch.tensor(cal, dtype=torch.float32)
        return rgb, depth

    def _prepare(self, image):
        """Apply the user transform, or convert a PIL image to a CHW float tensor."""
        if callable(self.transform):
            return self.transform(image)

        import numpy as np  # local import: numpy is not imported at file level

        array = np.array(image, dtype=np.float32) / 255.0
        if array.ndim == 2:
            # Single-channel images come back as (H, W); add the channel axis.
            array = array[:, :, None]
        return torch.from_numpy(array).permute(2, 0, 1).contiguous()
            
        

In [None]:
# Both splits live under the current working directory.
train_set = MyDataset(root=Path("."), if_train=True, transform=False)
test_set = MyDataset(root=Path("."), data_dir="test", if_train=False, transform=False)

# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=32, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)


# Model Definition

In [None]:
def conv_block(c_in, c_out, k=5, s=1, p=1):
    """Build a bias-free Conv2d -> BatchNorm2d -> in-place ReLU stack.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
        k: Convolution kernel size.
        s: Convolution stride.
        p: Convolution padding.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    layers = [
        nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p, bias=False),
        nn.BatchNorm2d(c_out),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

class RGBBranch(nn.Module):
    """Convolutional encoder for the colour image.

    Stacks ``conv_block`` stages with 2x2 max-pooling in between, then
    global-average-pools the final 256-channel map to one vector per sample.
    """

    def __init__(self, in_ch=3):
        super().__init__()
        # Integers are the output widths of successive conv blocks;
        # "pool" marks a 2x2 max-pool.
        plan = (32, 32, "pool", 64, 64, "pool", 128, "pool", 256, "pool")

        layers = []
        channels = in_ch
        for step in plan:
            if step == "pool":
                layers.append(nn.MaxPool2d(2))
            else:
                layers.append(conv_block(channels, step))
                channels = step

        self.stem = nn.Sequential(*layers)
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Encode ``x`` and return the pooled features flattened from dim 1."""
        features = self.stem(x)
        pooled = self.gap(features)
        return pooled.flatten(1)

class DepthBranch(nn.Module):
    """Convolutional encoder for the depth map.

    A narrower counterpart of the colour encoder: ``conv_block`` stages with
    2x2 max-pooling, ending in a global average pool over 128 channels.
    """

    def __init__(self, in_ch=1):
        super().__init__()
        layers = [
            conv_block(in_ch, 16),
            conv_block(16, 16),
            nn.MaxPool2d(2),
            conv_block(16, 32),
            nn.MaxPool2d(2),
            conv_block(32, 64),
            nn.MaxPool2d(2),
            conv_block(64, 128),
            nn.MaxPool2d(2),
        ]
        self.stem = nn.Sequential(*layers)
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Encode ``x`` and return the pooled features flattened from dim 1."""
        pooled = self.gap(self.stem(x))
        return torch.flatten(pooled, 1)

class RGBDNet(nn.Module):
    """Two-branch regressor that fuses RGB and depth features.

    The 256 RGB features and 128 depth features are concatenated and passed
    through an MLP head that emits one value per sample.

    Args:
        use_dropout: When truthy, insert ``Dropout(0.1)`` after the first
            hidden layer of the head.
    """

    def __init__(self, use_dropout=True):
        super().__init__()
        self.rgb = RGBBranch(3)
        self.depth = DepthBranch(1)

        fusion_dim = 256 + 128
        entry = [nn.Linear(fusion_dim, 128), nn.ReLU(inplace=True)]
        regulariser = [nn.Dropout(0.1)] if use_dropout else []
        tail = [
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 1),
        ]
        self.head = nn.Sequential(*entry, *regulariser, *tail)

    def forward(self, rgb, depth):
        """Return one prediction per sample, with the trailing unit dim removed."""
        rgb_features = self.rgb(rgb)
        depth_features = self.depth(depth)
        fused = torch.cat((rgb_features, depth_features), dim=1)
        return self.head(fused).squeeze(1)



# Training

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RGBDNet()
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# torch.nn has no L2Loss; mean-squared error is the L2 regression criterion.
criterion = torch.nn.MSELoss()

for epoch in range(10):
    # ---- Train ----
    model.train()
    # The training dataset yields (rgb, depth, label) triples.
    for rgb, depth, cal in train_loader:
        rgb, depth, cal = rgb.to(device), depth.to(device), cal.to(device)
        # RGBDNet.forward already squeezes its output to shape (batch,), so no
        # further squeeze is applied here.
        pred = model(rgb, depth)
        loss = criterion(pred, cal)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()