In [1]:
# -------------------------------------------------------------------------------------------------------------
# Imports
# -------------------------------------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as functional
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

In [2]:
# -------------------------------------------------------------------------------------------------------------
# System path for imports
# -------------------------------------------------------------------------------------------------------------
# Make project-local packages (e.g. `architecture/`) importable from this notebook.
PROJECT_ROOT = './'
import sys

# Guard the append so re-running this cell does not pile up duplicate sys.path entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

## Experiment 1: CNN + Transformer Hybrid Architecture

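This experiment trains and runs the hybrid CNN + Vision Transformer (ViT) architecture, which combines features via an element-wise product. The model is trained to map blurry input images to their sharp counterparts using an MSE loss.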

### CNN-ViT Hybrid Model

In [None]:
# -------------------------------------------------------------------------------------------------------------
# Training Loop
# -------------------------------------------------------------------------------------------------------------

def train(model, dataloader, optimizer, epochs=5, device=None):
    """Train ``model`` to reconstruct sharp images from blurry ones with an MSE loss.

    Args:
        model: ``nn.Module`` to optimise. ``Module.to`` moves its parameters in
            place, so an optimizer built over ``model.parameters()`` stays valid.
        dataloader: Re-iterable source of ``(blurry, sharp)`` tensor batches,
            e.g. a ``torch.utils.data.DataLoader`` or a list of tuples. It does
            not need to support ``len()``.
        optimizer: Optimizer over the model's parameters.
        epochs: Number of full passes over ``dataloader``.
        device: Device (or device string) to train on. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        list[float]: The mean of the per-batch losses for each epoch, in order.

    Raises:
        ValueError: If an epoch yields no batches (its average loss would be
            undefined).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.MSELoss()

    history = []
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        # Count batches actually seen instead of len(dataloader): works for
        # loaders without a length and avoids dividing by zero on empty input.
        num_batches = 0

        for blurry, sharp in dataloader:
            blurry = blurry.to(device)
            sharp = sharp.to(device)

            optimizer.zero_grad()
            output = model(blurry)
            loss = criterion(output, sharp)
            loss.backward()
            optimizer.step()

            # .item() yields a Python float, so the autograd graph is not retained.
            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches; cannot compute epoch loss")

        avg = total_loss / num_batches
        history.append(avg)
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg:.4f}")

    return history






In [None]:
# -------------------------------------------------------------------------------------------------------------
# Driver Code
# -------------------------------------------------------------------------------------------------------------
from architecture.cnn_vit_hybrid_architecture import CNN_VIT_HYBRID_ARCHITECTURE

# Input image and the spatial size the transform resizes it to.
BLUR_IMAGE_PATH = "./Blur.png"
IMAGE_SIZE = (224, 224)

transform = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor()])

# Context manager closes the underlying file; .convert returns an independent copy.
with Image.open(BLUR_IMAGE_PATH) as source_image:
    raw_image = source_image.convert('RGB')
image = transform(raw_image)  # (3, H, W) float tensor from ToTensor

model = CNN_VIT_HYBRID_ARCHITECTURE()
model.eval()  # inference mode (affects dropout / batch-norm layers, if any)

# Add a leading batch dimension: (3, H, W) -> (1, 3, H, W).
# NOTE(review): assumes the model expects batched NCHW input — confirm against the architecture module.
with torch.no_grad():
    output = model(image.unsqueeze(0))
output