In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm  

class FrameInterpolationModel(nn.Module):
    """Small CNN that maps two channel-stacked RGB frames to one RGB frame.

    Pipeline (see ``forward``):
      1. Three parallel extractors with identical layout (6 -> 64 -> 128
         channels, 3x3 convolutions, stride 1, padding 1, so the spatial size
         of the input is kept).
      2. Every feature map is bilinearly resized to ``feature_size``.
      3. The three maps are concatenated on the channel axis (3 * 128 = 384)
         and fused back to 128 channels.
      4. A stride-2 transposed convolution doubles the spatial size and
         produces 3 output channels.

    The output resolution is therefore ``2 * feature_size`` no matter what the
    input resolution is.

    Args:
        feature_size: ``(height, width)`` the feature maps are resized to
            before fusion. The default ``(90, 160)`` yields 180x320 outputs.
    """

    def __init__(self, feature_size=(90, 160)):
        super().__init__()
        # Attribute names and creation order are kept stable so state_dict keys
        # and seeded weight initialisation match earlier checkpoints.
        self.feature_extractor = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(6, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True)
            ) for _ in range(3)
        ])
        self.resize = nn.Upsample(size=tuple(feature_size), mode='bilinear', align_corners=False)
        self.fusion_conv = nn.Conv2d(384, 128, kernel_size=3, stride=1, padding=1)
        # kernel 3 / stride 2 / padding 1 / output_padding 1 gives exactly 2x
        # the input height and width.
        self.upsample_conv = nn.ConvTranspose2d(
            128, 3, kernel_size=3, stride=2, padding=1, output_padding=1, groups=1
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: batch of shape ``(N, 6, H, W)`` (the first convolution requires
                6 input channels).

        Returns:
            Tensor of shape ``(N, 3, 2 * feature_h, 2 * feature_w)``.
        """
        feature_maps = [extractor(x) for extractor in self.feature_extractor]
        feature_maps_resized = [self.resize(fm) for fm in feature_maps]
        fused = torch.cat(feature_maps_resized, 1)
        fused = F.relu(self.fusion_conv(fused))
        return self.upsample_conv(fused)

class FrameInterpolationDataset(Dataset):
    """Map-style dataset yielding triples of consecutive frames from one folder.

    Frames are read from ``root_dir`` using the file-name pattern
    ``RLCAFTCONF-C0_{frame_number:06d}.jpeg``. Sample ``idx`` consists of the
    frames numbered ``n``, ``n + 1`` and ``n + 2`` where
    ``n = start_frame_number + idx % num_frames_per_video``; the first two are
    returned as inputs and the third as the target.

    NOTE(review): the target is the frame *after* both inputs (next-frame
    prediction), not a frame between them - confirm this is intended given the
    "interpolation" naming.
    NOTE(review): the index is folded with ``% num_frames_per_video`` and no
    per-video offset is applied, so each of the ``repeats`` blocks of indices
    maps to the same files. Confirm whether distinct videos were intended.

    Args:
        root_dir: directory containing the frame images.
        transform: optional callable applied to each of the three PIL images.
        num_frames_per_video: number of distinct starting frames.
        start_frame_number: frame number of the first starting frame.
        repeats: how many times the block of starting frames is repeated in
            the index space (``len(dataset) == num_frames_per_video * repeats``).
    """

    def __init__(self, root_dir, transform=None, num_frames_per_video=190,
                 start_frame_number=100000, repeats=25):
        self.root_dir = root_dir
        self.transform = transform
        self.num_frames_per_video = num_frames_per_video
        self.start_frame_number = start_frame_number
        self.repeats = repeats

    def __len__(self):
        """Return the number of addressable samples."""
        return self.num_frames_per_video * self.repeats

    def _frame_path(self, frame_number):
        """Build the path of the image file for ``frame_number``."""
        return os.path.join(self.root_dir, f"RLCAFTCONF-C0_{frame_number:06d}.jpeg")

    def __getitem__(self, idx):
        """Load the sample at ``idx``.

        Returns:
            ``(frame1, frame2, target)`` - PIL images, or whatever
            ``transform`` returns for each of them.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        # Without this check plain iteration (`for sample in dataset`) never
        # stops, because the modulo below always produces a valid frame number.
        if not -length <= idx < length:
            raise IndexError(f"index {idx} out of range for dataset of length {length}")

        # Python's modulo already folds negative indices into range.
        first = self.start_frame_number + idx % self.num_frames_per_video
        paths = [self._frame_path(first + offset) for offset in range(3)]

        frame1, frame2, target = map(Image.open, paths)

        if self.transform:
            frame1, frame2, target = map(self.transform, (frame1, frame2, target))

        return frame1, frame2, target

# --- Configuration: everything a reader is likely to tune --------------------
DATA_DIR = 'C:\\Users\\adhik\\Frame Interpolation\\EPFL-RLC_dataset\\frames\\cam0'
BATCH_SIZE = 32
LEARNING_RATE = 0.001
LOG_EVERY = 10  # number of batches between loss printouts
num_epochs = 1

# Use the GPU when one is available; otherwise this is the original CPU run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = FrameInterpolationModel().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


print("Loading dataset...")
# The dataset opens files lazily, so check the folder now instead of failing
# on the first batch after "Dataset loaded." has already been printed.
if not os.path.isdir(DATA_DIR):
    raise FileNotFoundError(f"Frame directory not found: {DATA_DIR}")
dataset = FrameInterpolationDataset(root_dir=DATA_DIR, transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
print("Dataset loaded.")


print("Training model...")
model.train()
num_batches = len(dataloader)
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    epoch_loss = 0.0

    for batch_idx, (frame1, frame2, target) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}")):
        frame1 = frame1.to(device)
        frame2 = frame2.to(device)
        target = target.to(device)

        # Stack the two input frames on the channel axis: (N, 3+3, H, W).
        inputs = torch.cat((frame1, frame2), 1)
        optimizer.zero_grad()
        outputs = model(inputs)

        # Bring the ground truth to the network's output resolution before
        # comparing.
        target_resized = F.interpolate(target, size=outputs.shape[2:], mode='bilinear', align_corners=False)
        loss = criterion(outputs, target_resized)
        loss.backward()
        optimizer.step()

        # Read the scalar once; it is used for both the running sum and the log.
        batch_loss = loss.item()
        epoch_loss += batch_loss

        if batch_idx % LOG_EVERY == 0:
            # tqdm.write prints without corrupting the active progress bar.
            tqdm.write(f'Batch {batch_idx}/{num_batches}, Loss: {batch_loss}')

    avg_epoch_loss = epoch_loss / num_batches
    print(f'\nEpoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss}', flush=True)


Loading dataset...
Dataset loaded.
Training model...

Epoch 1/1


Epoch 1:   1%|▍                                                                      | 1/149 [01:20<3:19:35, 80.91s/it]

Batch 0/149, Loss: 0.15214115381240845


Epoch 1:   7%|█████▏                                                                | 11/149 [13:14<2:38:48, 69.05s/it]

Batch 10/149, Loss: 0.010839411988854408


Epoch 1:  14%|█████████▊                                                            | 21/149 [24:00<2:17:33, 64.48s/it]

Batch 20/149, Loss: 0.006725357845425606


Epoch 1:  21%|██████████████▌                                                       | 31/149 [34:57<2:10:22, 66.30s/it]

Batch 30/149, Loss: 0.004736681468784809


Epoch 1:  28%|███████████████████▎                                                  | 41/149 [46:00<1:59:39, 66.48s/it]

Batch 40/149, Loss: 0.003098583547398448


Epoch 1:  34%|███████████████████████▉                                              | 51/149 [57:01<1:47:20, 65.71s/it]

Batch 50/149, Loss: 0.00221150740981102


Epoch 1:  41%|███████████████████████████▊                                        | 61/149 [1:07:19<1:37:35, 66.54s/it]

Batch 60/149, Loss: 0.0019531032303348184


Epoch 1:  48%|████████████████████████████████▍                                   | 71/149 [1:18:52<1:30:58, 69.99s/it]

Batch 70/149, Loss: 0.0014893243787810206


Epoch 1:  54%|████████████████████████████████████▉                               | 81/149 [1:30:19<1:17:58, 68.81s/it]

Batch 80/149, Loss: 0.0012645451352000237


Epoch 1:  61%|█████████████████████████████████████████▌                          | 91/149 [1:41:48<1:07:41, 70.03s/it]

Batch 90/149, Loss: 0.0012808674946427345


Epoch 1:  68%|██████████████████████████████████████████████▊                      | 101/149 [1:53:11<54:46, 68.47s/it]

Batch 100/149, Loss: 0.0010609830496832728


Epoch 1:  74%|███████████████████████████████████████████████████▍                 | 111/149 [2:04:54<45:42, 72.16s/it]

Batch 110/149, Loss: 0.0009559508762322366


Epoch 1:  81%|████████████████████████████████████████████████████████             | 121/149 [2:17:18<32:57, 70.64s/it]

Batch 120/149, Loss: 0.0007668482139706612


Epoch 1:  88%|████████████████████████████████████████████████████████████▋        | 131/149 [2:28:32<19:55, 66.42s/it]

Batch 130/149, Loss: 0.0007557718781754375


Epoch 1:  95%|█████████████████████████████████████████████████████████████████▎   | 141/149 [2:39:28<08:47, 65.95s/it]

Batch 140/149, Loss: 0.000819932552985847


Epoch 1: 100%|█████████████████████████████████████████████████████████████████████| 149/149 [2:47:35<00:00, 67.48s/it]


Epoch 1/1, Average Loss: 0.004789141528683421





In [2]:
# Persist only the learned parameters (the state_dict), not the module object.
trained_weights = model.state_dict()
torch.save(trained_weights, 'frame_interpolation_model.pth')