In [6]:
import os
import math
from abc import abstractmethod

from PIL import Image
import requests
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader
import pickle
import pandas as pd
from PIL import Image
import cv2

%matplotlib inline

## Unet and Gaussian Diffusion

In [3]:
def timestep_embedding(timesteps, dim, max_period=1000):
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

class TimestepBlock(nn.Module):
    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """
    def forward(self, x, t_emb, c_emb, mask):
        for layer in self:
            if(isinstance(layer, TimestepBlock)):
                x = layer(x, t_emb, c_emb, mask)
            else:
                x = layer(x)
            return x

def norm_layer(channels):
    return nn.GroupNorm(32, channels)

class ResidualBlock(TimestepBlock):
    def __init__(self, in_channels, out_channels, time_channels, cond_channels, dropout):
        super().__init__()
        self.conv1 = nn.Sequential(
            norm_layer(in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        )
        
        self.time_emb = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_channels, out_channels)
        )
        
        self.cond_conv = nn.Sequential(
            nn.Conv2d(cond_channels, out_channels, kernel_size=3, padding=1),
            nn.SiLU()
        )
        
        self.conv2 = nn.Sequential(
            norm_layer(out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )
        
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()
    
    def forward(self, x, t, cond_img, mask):
        h = self.conv1(x)
        emb_t = self.time_emb(t)
        emb_cond = self.cond_conv(cond_img) * mask[:, None, None, None]
        h += (emb_t[:, :, None, None] + emb_cond)
        h = self.conv2(h)
        
        return h + self.shortcut(x)
    
class AttentionBlock(nn.Module):
    def __init__(self, channels, num_heads=1):
        super().__init__()
        self.num_heads = num_heads
        assert channels % num_heads == 0
        
        self.norm = norm_layer(channels)
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)
        
    def forward(self, x):
        B, C, H, W = x.shape
        qkv = self.qkv(self.norm(x))
        q, k, v = qkv.reshape(B *  self.num_heads, -1, H * W).chunk(3, dim=1)
        scale = 1. / math.sqrt(math.sqrt(C // self.num_heads))
        attn = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        attn = attn.softmax(dim=-1)
        h = torch.einsum("bts,bcs->bct", attn, v)
        h = h.reshape(B, -1, H, W)
        h = self.proj(h)
        return h + x
    
class UpSample(nn.Module):
    def __init__(self, channels, use_conv):
        super().__init__()
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
    
    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x

class DownSample(nn.Module):
    def __init__(self, channels, use_conv):
        super().__init__()
        self.use_conv = use_conv
        if use_conv:
            self.op = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=2)
        else:
            self.op = nn.AvgPool2d(stride=2)
    
    def forward(self, x):
        return self.op(x)
    
class Unet(nn.Module):
    def __init__(self, 
                 in_channels=2,
                 cond_channels=1,
                 model_channels=128,
                 out_channels=2,
                 num_res_blocks=2, 
                 attention_resolutions=(8, 16),
                 dropout=0,
                 channel_mult=(1, 2, 2, 2),
                 conv_resample=True,
                 num_heads=4):
        super().__init__()
        self.in_channels = in_channels
        self.cond_channels = cond_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions,
        self.dropout = dropout
        self.channel_mult = channel_mult,
        self.conv_resample = conv_resample
        self.num_heads = num_heads
        
        # time embedding
        time_emb_dim = model_channels * 4
        self.time_emb = nn.Sequential(
            nn.Linear(model_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )
        
        # down blocks
        self.down_blocks = nn.ModuleList([
            TimestepEmbedSequential(nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=1))
        ])
        down_block_channels = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [ResidualBlock(ch, model_channels * mult, time_emb_dim, cond_channels, dropout)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(AttentionBlock(ch, num_heads))
                self.down_blocks.append(TimestepEmbedSequential(*layers))
                down_block_channels.append(ch)
            if level != len(channel_mult) - 1:
                self.down_blocks.append(TimestepEmbedSequential(DownSample(ch, conv_resample)))
                down_block_channels.append(ch)
                ds *= 2
        
        # middle blocks
        self.middle_blocks = TimestepEmbedSequential(
            ResidualBlock(ch, ch, time_emb_dim, cond_channels, dropout),
            AttentionBlock(ch, num_heads),
            ResidualBlock(ch, ch, time_emb_dim, cond_channels, dropout)
        )
        
        # up blocks
        self.up_blocks = nn.ModuleList([])
        for level, mult in enumerate(channel_mult[::-1]):
            for i in range(num_res_blocks + 1):
                layers = [
                    ResidualBlock(ch + down_block_channels.pop(), model_channels * mult, time_emb_dim, cond_channels, dropout)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(AttentionBlock(ch, num_heads))
                if level != len(channel_mult) - 1 and i == num_res_blocks:
                    layers.append(UpSample(ch, conv_resample))
                    ds //= 2
                self.up_blocks.append(TimestepEmbedSequential(*layers))
                
        self.out = nn.Sequential(
            norm_layer(ch),
            nn.SiLU(),
            nn.Conv2d(ch, out_channels, kernel_size=3, padding=1)
        )
    
    def forward(self, x, timesteps, cond_img, mask):
        """
        Apply the model to an input batch.
        :param x: an [N x C x H x W] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param cond_img: a [N x cond_C x H x W] Tensor of conditional images.
        :param mask: a 1-D batch of conditioned/unconditioned.
        :return: an [N x C x ...] Tensor of outputs.
        """
        hs = []
        # time step embedding
        t_emb = self.time_emb(timestep_embedding(timesteps, dim=self.model_channels))
        
        # down step
        h = x
        for module in self.down_blocks:
            if cond_img.shape[2:] != h.shape[2:]:
                cond_img = F.interpolate(cond_img, size=h.shape[2:], mode='nearest')
            h = module(h, t_emb, cond_img, mask)
            hs.append(h)
        # mid stage
        if cond_img.shape[2:] != h.shape[2:]:
            cond_img = F.interpolate(cond_img, size=h.shape[2:], mode='nearest')
        h = self.middle_blocks(h, t_emb, cond_img, mask)
        
        # up stage
        for module in self.up_blocks:
            h_skip = hs.pop()
            
            if h.shape[2:] != h_skip.shape[2:]:
                h = F.interpolate(h, size=h_skip.shape[2:], mode='nearest')

            if cond_img.shape[2:] != h.shape[2:]:
                cond_img = F.interpolate(cond_img, size=h.shape[2:], mode='nearest')

            cat_in = torch.cat([h, h_skip], dim=1)
            h = module(cat_in, t_emb, cond_img, mask)
        
        return self.out(h)

In [4]:
# beta schedule
def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

def sigmoid_beta_schedule(timesteps):
    betas = torch.linspace(-6, 6, timesteps)
    betas = torch.sigmoid(betas) / (betas.max() - betas.min()) * (0.02 - betas.min()) / 10
    return betas

def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

In [None]:
class GaussianDiffusion:
    def __init__(
        self,
        timesteps=1000,
        beta_schedule='linear',
    ):
        self.timesteps = timesteps
        
        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        elif beta_schedule == 'sigmoid':
            betas = sigmoid_beta_schedule(timesteps)
        else:
            raise ValueError(f'Unknown beta schedule {beta_schedule}')
        
        self.betas = betas
        
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.)
        
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
        
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_log_variance_clipped = torch.log(
            torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]])
        )
        
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)
        )
    
    # get the param of given timestep t
    def _extract(self, a, t, x_shape):
        batch_size = t.shape[0]
        out = a.to(t.device).gather(0, t).float()
        out = out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
        return out
    
    # forward diffusion : q(x_t | x_0)
    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
        
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
    
    # mean and variance of q(x_t | x_0)
    def q_mean_variance(self, x_start, t):
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance
    
    # mean and variance of diffusion posterior: q(x_{t-1} | x_t, x_0)
    def q_posterior_mean_variance(self, x_start, x_t, t):
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
    
    # compute x_0 from x_t and pred noise: reverse of q_sample
    def predict_start_from_noise(self, x_t, t, noise):
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )
    
    # compute predicted mean and variance of p(x_{t-1} | x_t) 
    def p_mean_variance(self, model, x_t, t, cond_img, w, clip_denoised=True):
        device = next(model.parameters()).device
        batch_size = x_t.shape[0]
        
        # noise prediction from model
        pred_noise_cond = model(x_t, t, cond_img, torch.ones(batch_size).int().to(device))
        pred_noise_uncond = model(x_t, t, cond_img, torch.zeros(batch_size).int().to(device))
        pred_noise = (1 + w) * pred_noise_cond - w * pred_noise_uncond
        
        # get predicted x_0
        x_recon = self.predict_start_from_noise(x_t, t, pred_noise)
        if clip_denoised:
            x_recon = torch.clamp(x_recon, min=-1., max=1.)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior_mean_variance(x_recon, x_t, t)
        
        return model_mean, posterior_variance, posterior_log_variance
    
    # denoise step: sample x_{t-1} from x_t and pred noise
    @torch.no_grad()
    def p_sample(self, model, x_t, t, cond_img, w, clip_denoised=True):
        # pred mean and variance
        model_mean, _, model_log_variance = self.p_mean_variance(model, x_t, t, cond_img, w, clip_denoised=clip_denoised)
        
        noise = torch.randn_like(x_t)
        # no noise when t = 0 
        nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1))))
        # compute x_{t-1}
        pred_img = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return pred_img
    
    # denoise : reverse diffusion
    @torch.no_grad()
    def p_sample_loop(self, model, shape, cond_img, w=2, clip_denoised=True):
        batch_size = shape[0]
        device = next(model.parameters()).device
        
        # start from pure noise
        img = torch.randn(shape, device=device)
        imgs = []
        for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps):
            img = self.p_sample(model, img, torch.full((batch_size,), i, device=device, dtype=torch.long), cond_img, w, clip_denoised)
            imgs.append(img.cpu().numpy())
        return imgs
    
    # sample new images
    @torch.no_grad
    def sample(self, model, image_size, cond_img, batch_size=8, channels=3, w=2, clip_denoised=True):
        return self.p_sample_loop(model, (batch_size, channels, image_size, image_size), cond_img, w, clip_denoised)
    
    # use ddim to sample
    @torch.no_grad()
    def ddim_sample(
        self,
        model,
        image_size,
        cond_img,
        batch_size=8,
        channels=3,
        ddim_timesteps=50,
        w=2,
        ddim_discr_method="uniform",
        ddim_eta=0.0,
        clip_denoised=True):
        
        # make ddim timestep sequence
        if ddim_discr_method == 'uniform':
            c = self.timesteps // ddim_timesteps
            ddim_timestep_seq = np.asarray(list(range(0, self.timesteps, c)))
        elif ddim_discr_method == 'quad':
            ddim_timestep_seq = (
                (np.linspace(0, np.sqrt(self.timesteps * .8), ddim_timesteps)) ** 2
            ).astype(int)
        else:
            raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
        # add one to get the final alpha values right (the ones from first scale to data during sampling)
        ddim_timestep_seq = ddim_timestep_seq + 1
        # previous sequence
        ddim_timestep_prev_seq = np.append(np.array([0]), ddim_timestep_seq[:-1])
        
        device = next(model.parameters()).device
        
        # start from pure noise (for each example in the batch)
        sample_img = torch.randn((batch_size, channels, image_size, image_size), device=device)
        seq_img = [sample_img.cpu().numpy()]   
        
        for i in tqdm(reversed(range(0, ddim_timesteps)), desc='sampling loop time step', total=ddim_timesteps):
            t = torch.full((batch_size,), ddim_timestep_seq[i], device=device, dtype=torch.long)
            prev_t = torch.full((batch_size,), ddim_timestep_prev_seq[i], device=device, dtype=torch.long)
            
            # 1. get current and previous alpha_cumprod
            alpha_cumprod_t = self._extract(self.alphas_cumprod, t, sample_img.shape)
            alpha_cumprod_t_prev = self._extract(self.alphas_cumprod, prev_t, sample_img.shape)
    
            # 2. predict noise using model
            pred_noise_cond = model(sample_img, t, cond_img, torch.ones(batch_size).int().cuda())
            pred_noise_uncond = model(sample_img, t, cond_img, torch.zeros(batch_size).int().cuda())
            pred_noise = (1+w)*pred_noise_cond - w*pred_noise_uncond
            
            # 3. get the predicted x_0
            pred_x0 = (sample_img - torch.sqrt((1. - alpha_cumprod_t)) * pred_noise) / torch.sqrt(alpha_cumprod_t)
            if clip_denoised:
                pred_x0 = torch.clamp(pred_x0, min=-1., max=1.)
            
            # 4. compute variance: "sigma_t(η)" -> see formula (16)
            # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
            sigmas_t = ddim_eta * torch.sqrt(
                (1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t) * (1 - alpha_cumprod_t / alpha_cumprod_t_prev))
            
            # 5. compute "direction pointing to x_t" of formula (12)
            pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev - sigmas_t**2) * pred_noise
            
            # 6. compute x_{t-1} of formula (12)
            x_prev = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt + sigmas_t * torch.randn_like(sample_img)

            sample_img = x_prev

        return sample_img.cpu().numpy()
    
    # compute train losses
    def train_losses(self, model, x_start, t, cond_img, mask_c):
        # generate random noise
        noise = torch.randn_like(x_start)
        # get x_t
        x_noisy = self.q_sample(x_start, t, noise=noise)
        predicted_noise = model(x_noisy, t, cond_img, mask_c)
        loss = F.mse_loss(noise, predicted_noise)
        return loss 

## Dataset

In [69]:
class ConditionalImageDataset(Dataset):
    def __init__(self, data_dir, transform=None, conditional_offset=5):
        self.data_dir = data_dir
        self.transform = transform
        self.conditional_offset = conditional_offset
        self.regression_csv_path = os.path.join(data_dir, "regression_params.csv")
        self.cond_images = []
        self.target_images = []
        self.reg_data = None
        self._load_data()

    def _load_data(self):
        files = sorted([os.path.join(self.data_dir, f) for f in os.listdir(self.data_dir) if f.endswith('.mpy')])
        for file in files:
            with open(file, 'rb') as f:
                images = pickle.load(f)
                if isinstance(images, list):
                    images = np.array(images)
                    
                for img_idx in range(len(images) - self.conditional_offset):
                    self.cond_images.append(images[img_idx])
                    self.target_images.append(images[img_idx + self.conditional_offset])
                    
        headers = ["p_h", "a", "c_1", "c_2"]
        self.reg_data = pd.read_csv(self.regression_csv_path, names=headers, index_col=0)

    def __len__(self):
        return len(self.cond_images)

    def __getitem__(self, idx):
        # collection_idx, image_idx = self._get_indices(idx)
        cond_image = self.cond_images[idx]
        image = self.target_images[idx]
        reg_params = self.reg_data.loc[idx // 15].to_numpy()

        if self.transform:
            image = self.transform(image)
            cond_image = self.transform(cond_image)

        return image, cond_image, reg_params

In [82]:
batch_size = 32

DATA_DIR = "C:/Users/Anirbit/Desktop/MSc/Ind Project/Msc-Project/data/simulated_bin_frames"
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(64)
])

dataset = ConditionalImageDataset(DATA_DIR, transform=transform)

# Split dataset into training and testing
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# Create DataLoader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)

device = "cuda" if torch.cuda.is_available() else "cpu"

## Regression Model

In [55]:
# Define the CNN model for image feature extraction
class CNNRegressionModel(nn.Module):
    def __init__(self):
        super(CNNRegressionModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 4)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        x = x.view(-1, 128 * 8 * 8)  # Flatten
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [56]:
model = CNNRegressionModel()
model.to(device)

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [78]:
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for _, cond_img, target_params in train_loader:
        
        # print(torch.tensor(np.array(target_params)))
        optimizer.zero_grad()
        
        cond_img = cond_img.float().to(device)
        target_params = target_params.float().to(device)
        
        outputs = model(cond_img)
        loss = criterion(outputs, target_params)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")

print("Training Finished")

Epoch 1/20, Loss: 0.009059743077324873
Epoch 2/20, Loss: 0.004588235846974633
Epoch 3/20, Loss: 0.004109623492695391
Epoch 4/20, Loss: 0.0035810038997707043
Epoch 5/20, Loss: 0.0031332837692885237
Epoch 6/20, Loss: 0.002429625606799329
Epoch 7/20, Loss: 0.0018854059572649103
Epoch 8/20, Loss: 0.001355365984967317
Epoch 9/20, Loss: 0.0008236551002482884
Epoch 10/20, Loss: 0.0005159064345538023
Epoch 11/20, Loss: 0.0003175455732906068
Epoch 12/20, Loss: 0.00019109417983351952
Epoch 13/20, Loss: 0.00016626493875678122
Epoch 14/20, Loss: 0.00012522021676307884
Epoch 15/20, Loss: 0.00010506114598353055
Epoch 16/20, Loss: 8.819940243979958e-05
Epoch 17/20, Loss: 7.10658256469866e-05
Epoch 18/20, Loss: 5.319180490914732e-05
Epoch 19/20, Loss: 4.031423858660591e-05
Epoch 20/20, Loss: 3.981428174235837e-05
Training Finished


In [87]:
# Define the evaluation function
def evaluate_model(model, test_loader):
    model.eval()  # Set the model to evaluation mode
    total_loss = 0.0
    criterion = nn.MSELoss()
    
    with torch.no_grad():  # Disable gradient calculation
        for _, cond_img, reg_params in test_loader:
            cond_img = cond_img.float().to(device)
            reg_params = reg_params.float().to(device)
            
            outputs = model(cond_img)
            
            print(f"Model Pred: {outputs.cpu().numpy()[0]}, Ground Truth: {outputs.cpu().numpy()[0]}")
            
            loss = criterion(outputs, reg_params)
            total_loss += loss.item()
    
    avg_loss = total_loss / len(test_loader)
    return avg_loss

# Example usage
# Assuming you have a DataLoader for the test dataset named test_loader
# and a trained model named model

test_loss = evaluate_model(model, test_loader)
print(f'Test Loss: {test_loss}')

Model Pred: [0.25104594 0.05430795 0.0489934  0.06473117], Ground Truth: [0.25104594 0.05430795 0.0489934  0.06473117]
Model Pred: [0.23104388 0.10700779 0.0708842  0.35263765], Ground Truth: [0.23104388 0.10700779 0.0708842  0.35263765]
Model Pred: [0.25993335 0.13719802 0.04459206 0.11923321], Ground Truth: [0.25993335 0.13719802 0.04459206 0.11923321]
Model Pred: [0.20934671 0.05516015 0.05602372 0.02707585], Ground Truth: [0.20934671 0.05516015 0.05602372 0.02707585]
Model Pred: [0.20579284 0.05028914 0.05309315 0.01605929], Ground Truth: [0.20579284 0.05028914 0.05309315 0.01605929]
Model Pred: [0.3032931  0.13428353 0.04280767 0.23829897], Ground Truth: [0.3032931  0.13428353 0.04280767 0.23829897]
Model Pred: [0.21421385 0.12240772 0.10338275 0.12605475], Ground Truth: [0.21421385 0.12240772 0.10338275 0.12605475]
Model Pred: [0.28019774 0.06693217 0.04069465 0.364071  ], Ground Truth: [0.28019774 0.06693217 0.04069465 0.364071  ]
Model Pred: [0.2616902  0.05692882 0.10232569 0.