# Introduction to Crossview

# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class Encoder(nn.Module):
    """Convolutional encoder that maps a batch of square images to a latent vector ``phi``.

    Two convolutions with ``kernel_size=4, stride=2, padding=1`` each map a
    side length ``s`` to ``s // 2``. An input of shape
    ``(N, image_channels, image_size, image_size)`` is therefore reduced to
    ``(N, hidden_dims * 2, image_size // 4, image_size // 4)``. That map is
    flattened and projected to ``output_dims`` features.

    Args:
        image_channels: Number of input channels (e.g. 3 for RGB).
        hidden_dims: Output channels of the first convolution; the second
            convolution doubles this.
        output_dims: Size of the latent vector ``phi``.
        image_size: Side length of the square input images. The linear
            layer's input width is derived from it, so inputs passed to
            ``forward`` must have exactly this spatial size. Defaults to 64.

    Raises:
        ValueError: If ``image_size`` is smaller than 4. The convolutions would
            then leave no spatial positions for the linear layer.
    """

    def __init__(self, image_channels, hidden_dims, output_dims, image_size=64):
        super().__init__()
        if image_size < 4:
            raise ValueError(f"image_size must be at least 4, got {image_size}")
        # Side length after the two stride-2 convolutions.
        feature_size = image_size // 4
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, hidden_dims, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dims, hidden_dims * 2, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(hidden_dims * 2 * feature_size * feature_size, output_dims),
            # Final ReLU: every component of phi is non-negative.
            nn.ReLU(),
        )

    def forward(self, x):
        """Encode ``x`` of shape ``(N, image_channels, image_size, image_size)``
        into ``phi`` of shape ``(N, output_dims)``."""
        return self.encoder(x)

class Decoder(nn.Module):
    """Transposed-convolution decoder that maps a latent vector back to an image batch.

    Steps:
    1. The input of shape ``(N, input_dims)`` is linearly projected.
    2. It is reshaped into a ``(hidden_dims * 2, image_size // 4, image_size // 4)``
       feature map.
    3. Two transposed convolutions (``kernel_size=4, stride=2, padding=1``) each
       double the side length.

    The output has shape ``(N, image_channels, 4 * (image_size // 4), 4 * (image_size // 4))``.
    This equals ``image_size`` only when ``image_size`` is a multiple of 4.

    Args:
        input_dims: Size of the latent vector fed to the decoder.
        hidden_dims: Output channels of the first transposed convolution; the
            projected feature map has twice this many channels.
        image_channels: Number of output channels (e.g. 3 for RGB).
        image_size: Target side length of the square output image.
    """

    def __init__(self, input_dims, hidden_dims, image_channels, image_size):
        super().__init__()
        side = image_size // 4
        wide = 2 * hidden_dims
        # Keep this layer order: the Sequential indices determine state_dict keys.
        layers = [
            nn.Linear(input_dims, wide * side * side),
            nn.ReLU(),
            nn.Unflatten(1, (wide, side, side)),
            nn.ConvTranspose2d(wide, hidden_dims, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dims, image_channels, kernel_size=4, stride=2, padding=1),
            # Sigmoid squashes pixels into (0, 1); intended for targets normalized to [0, 1].
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode latent ``x`` of shape ``(N, input_dims)`` into an image batch."""
        return self.decoder(x)

# Example instantiation: an autoencoder for 64x64 RGB images.
image_size = 64  # Side length of the square input images
image_channels = 3  # RGB
hidden_dims = 64  # Base channel width of the conv stacks
output_dims = 100  # Size of the latent code phi; tune as needed

encoder = Encoder(
    image_channels=image_channels,
    hidden_dims=hidden_dims,
    output_dims=output_dims,
)
decoder = Decoder(
    input_dims=output_dims,
    hidden_dims=hidden_dims,
    image_channels=image_channels,
    image_size=image_size,
)


# In [None]:
# Train the encoder/decoder pair as an autoencoder.
# Assumes you have a DataLoader named `dataloader` for your dataset that yields
# (image, label) pairs; the label is ignored.
num_epochs = 10  # Number of passes over the dataset; adjust as needed
learning_rate = 1e-3  # Adam step size; adjust as needed

# One optimizer updates both halves, since the reconstruction loss depends on
# the parameters of the encoder and the decoder.
optimizer = optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()),
    lr=learning_rate,
)

for epoch in range(num_epochs):
    for img, _ in dataloader:
        # Normalize images to [0, 1] to match the decoder's Sigmoid output range.
        # NOTE(review): this assumes the DataLoader yields raw 0-255 pixel values.
        # If it already applies torchvision's ToTensor(), images are in [0, 1]
        # and this division should be removed.
        img = img / 255.0

        optimizer.zero_grad()
        phi = encoder(img)
        recon_img = decoder(phi)

        # Mean squared error between the input and its reconstruction.
        loss = nn.functional.mse_loss(recon_img, img)

        loss.backward()
        optimizer.step()
