In [None]:
"""
BIDIRECTIONAL TRANSFORMER WITH RELATIVE POSITIONAL BIAS + CUSTOM MULTI-HEAD ATTENTION LAYER
"""

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision.models import vgg16, VGG16_Weights
from torch.nn.functional import interpolate

# Device configuration: first CUDA device when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
channels = 3          # channels expected in the data file (checked below)
embed_dim = 256       # transformer model width (d_model)
num_heads = 4         # attention heads per attention module
hidden_dim = 512      # feed-forward width inside each transformer layer
num_layers = 2        # number of encoder layers, and separately of decoder layers
dropout = 0.1
N = 30 # number of images of the initial sequence
K = 20 # number of images of the output sequences (predicted images)
batch_size = 16
num_epochs = 1000
learning_rate = 0.001

# Load data. Expected layout: (num_samples, height, width, channels) with square images.
data = np.load("Modigliani_paintings.npy")
#data = np.load("Academic_Art_paintings.npy")
# Explicit checks instead of `assert`, which is stripped when Python runs with -O.
if data.ndim != 4:
    raise ValueError(
        f"Expected a 4-D array (samples, height, width, channels), got shape {data.shape}"
    )
num_samples, image_size, image_width, channels_check = data.shape
if channels != channels_check:
    raise ValueError(f"Expected {channels} channels, got {channels_check}")
if image_width != image_size:
    # The model reshapes every frame to (image_size, image_size); non-square data would crash later.
    raise ValueError(f"Expected square images, got {image_size}x{image_width}")
# NOTE(review): the decoder ends in a Sigmoid and plotting clips to [0, 1], so pixel
# values are presumably already scaled to [0, 1] -- TODO confirm for each .npy file.
data = np.transpose(data, (0, 3, 1, 2))  # -> (num_samples, channels, height, width)
image_dim = channels * image_size * image_size  # length of one flattened frame

# Create sequences
def create_sequences(data, N, K):
    """Slice ``data`` into every window of N input frames followed by K target frames.

    Windows advance one frame at a time, so consecutive windows overlap.

    Args:
        data: array whose first axis indexes frames.
        N: number of input frames per window.
        K: number of target frames per window.

    Returns:
        ``(sequences, targets)`` with shapes ``(M, N, *data.shape[1:])`` and
        ``(M, K, *data.shape[1:])``, where ``M = max(len(data) - N - K + 1, 0)``.
    """
    # Use the length of the array actually passed in, not a module-level sample count.
    num_windows = max(len(data) - (N + K) + 1, 0)
    frame_shape = tuple(data.shape[1:])
    if num_windows == 0:
        # np.array([]) would drop the frame dimensions; keep them so callers can reshape.
        return (
            np.empty((0, N) + frame_shape, dtype=data.dtype),
            np.empty((0, K) + frame_shape, dtype=data.dtype),
        )
    sequences = np.stack([data[i:i + N] for i in range(num_windows)])
    targets = np.stack([data[i + N:i + N + K] for i in range(num_windows)])
    return sequences, targets

# Overlapping (input, target) windows converted to float32 tensors.
train_data, train_targets = (
    torch.tensor(array, dtype=torch.float32)
    for array in create_sequences(data, N, K)
)

# Same tensors seen as one flattened vector per frame (views sharing storage).
train_data_flat = train_data.view(-1, N, image_dim)
train_targets_flat = train_targets.view(-1, K, image_dim)

# DataLoader: mini-batches of (input sequence, target sequence), shuffled each epoch.
train_dataset = torch.utils.data.TensorDataset(train_data, train_targets)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Custom multi-head attention with a learned relative positional bias
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with a learned relative-position bias.

    For each head, a learned scalar indexed by the signed distance
    ``query_position - key_position`` is added to the attention logits before the
    softmax. Distances outside ``[-(max_len - 1), max_len - 1]`` are clamped to the
    nearest learned entry.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        max_len: sequence length the bias table is sized for. When omitted it
            falls back to ``max(N, K)`` read from the module-level hyper-parameters.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, max_len=None):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

        # Backward-compatible default: size the table from the global sequence lengths.
        self.max_len = max(N, K) if max_len is None else max_len
        # One row per signed distance in [-(max_len-1), max_len-1], one column per head.
        self.rel_pos_bias = nn.Parameter(torch.randn(self.max_len * 2 - 1, num_heads))

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: (batch, q_len, embed_dim).
            key, value: (batch, k_len, embed_dim).
            mask: optional tensor broadcastable to (batch, num_heads, q_len, k_len);
                entries equal to 0 are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        batch_size, q_len, _ = query.size()
        k_len = key.size(1)

        def split_heads(x, length):
            # (batch, length, embed_dim) -> (batch, heads, length, head_dim)
            return x.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(query), q_len)
        k = split_heads(self.k_proj(key), k_len)
        v = split_heads(self.v_proj(value), k_len)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (batch, heads, q_len, k_len)

        # Build position indices on the activations' device, not a global one, so the
        # module works wherever it (and its inputs) have been moved.
        pos_q = torch.arange(q_len, device=scores.device)
        pos_k = torch.arange(k_len, device=scores.device)
        rel_index = pos_q[:, None] - pos_k[None, :] + (self.max_len - 1)
        rel_index = rel_index.clamp(0, 2 * self.max_len - 2)
        bias = self.rel_pos_bias[rel_index]  # (q_len, k_len, num_heads)
        scores = scores + bias.permute(2, 0, 1).unsqueeze(0)  # broadcast over batch

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = self.dropout(torch.softmax(scores, dim=-1))
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, q_len, self.embed_dim)
        return self.out_proj(out)

# Custom Transformer Encoder Layer
class CustomTransformerEncoderLayer(nn.Module):
    """Post-norm transformer encoder layer using the relative-position MultiHeadAttention.

    Two sub-layers, each applied as ``LayerNorm(x + Dropout(sublayer(x)))``:
    self-attention, then a two-layer ReLU feed-forward network.

    Args:
        d_model: token width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability for attention weights, the feed-forward hidden
            layer and both residual branches.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout):
        super().__init__()
        # Attention sub-layer.
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout)
        # Feed-forward sub-layer: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # Post-residual normalisation and residual-branch dropout.
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, src):
        """Apply the layer to ``src`` (batch, seq_len, d_model); the shape is preserved."""
        attended = self.self_attn(src, src, src)
        x = self.norm1(src + self.dropout1(attended))
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.norm2(x + self.dropout2(self.linear2(hidden)))

# Custom Transformer Decoder Layer
class CustomTransformerDecoderLayer(nn.Module):
    """Post-norm transformer decoder layer using the relative-position MultiHeadAttention.

    Three sub-layers, each applied as ``LayerNorm(x + Dropout(sublayer(x)))``:
    (optionally masked) self-attention, cross-attention into the encoder memory,
    and a two-layer ReLU feed-forward network.

    Args:
        d_model: token width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability for attention weights, the feed-forward hidden
            layer and all residual branches.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout):
        super().__init__()
        # Self-attention over the decoder sequence, then cross-attention into memory.
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout)
        self.multihead_attn = MultiHeadAttention(d_model, nhead, dropout)
        # Feed-forward sub-layer: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # One LayerNorm and one residual dropout per sub-layer.
        self.norm1, self.norm2, self.norm3 = (
            nn.LayerNorm(d_model), nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        )
        self.dropout1, self.dropout2, self.dropout3 = (
            nn.Dropout(dropout), nn.Dropout(dropout), nn.Dropout(dropout)
        )
        self.activation = nn.ReLU()

    def forward(self, tgt, memory, tgt_mask=None):
        """Run one decoder layer.

        Args:
            tgt: decoder sequence (batch, tgt_len, d_model).
            memory: encoder output (batch, src_len, d_model).
            tgt_mask: optional self-attention mask forwarded to ``self_attn``.

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        x = self.norm1(tgt + self.dropout1(self.self_attn(tgt, tgt, tgt, mask=tgt_mask)))
        x = self.norm2(x + self.dropout2(self.multihead_attn(x, memory, memory)))
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.norm3(x + self.dropout3(self.linear2(hidden)))

# Encoder-Decoder with CNNs
class EncoderTransformerDecoder(nn.Module):
    """CNN frame encoder -> transformer encoder/decoder -> CNN frame decoder.

    Each frame is embedded by ``encoder_cnn`` (three 3x3 convolutions followed by
    global average pooling) into an ``embed_dim`` vector. The source embeddings
    pass through the encoder layers to form the memory. Each decoder output vector
    is expanded by ``decoder_cnn_input`` to an ``embed_dim x (image_size/4) x
    (image_size/4)`` map and upsampled 4x by ``decoder_cnn`` (two stride-2
    transposed convolutions), ending in a Sigmoid.

    Training uses teacher forcing: the decoder input is the last source frame
    followed by the first K-1 target frames, under a causal mask. Output position i
    therefore predicts target frame i only from frames strictly before it, which
    matches step-by-step autoregressive generation.

    Args:
        image_size: height/width of each square frame; must be divisible by 4.
        channels: channels per frame.
        embed_dim: transformer width.
        num_heads: attention heads.
        hidden_dim: feed-forward width inside the transformer layers.
        num_layers: number of encoder layers and (separately) decoder layers.
        dropout: dropout probability inside the transformer layers.
        N, K: accepted for interface compatibility; sequence lengths are read from
            the tensors passed to ``forward``.

    Raises:
        ValueError: if ``image_size`` is not divisible by 4.
    """

    def __init__(self, image_size, channels, embed_dim, num_heads, hidden_dim, num_layers, dropout, N, K):
        super().__init__()
        if image_size % 4 != 0:
            raise ValueError(f"image_size must be divisible by 4, got {image_size}")

        self.image_size = image_size
        self.channels = channels
        self.embed_dim = embed_dim
        latent = image_size // 4  # spatial size before the two 2x upsampling stages

        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, embed_dim, kernel_size=3, padding=1)
        )
        self.encoder_layers = nn.ModuleList([
            CustomTransformerEncoderLayer(embed_dim, num_heads, hidden_dim, dropout)
            for _ in range(num_layers)
        ])

        self.decoder_cnn_input = nn.Linear(embed_dim, embed_dim * latent * latent)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, channels, kernel_size=3, padding=1),
            nn.Sigmoid()
        )
        self.decoder_layers = nn.ModuleList([
            CustomTransformerDecoderLayer(embed_dim, num_heads, hidden_dim, dropout)
            for _ in range(num_layers)
        ])

    def _embed_frames(self, frames):
        """Embed a (batch, length, C, H, W) frame tensor into (batch, length, embed_dim)."""
        batch, length = frames.shape[:2]
        flat = frames.reshape(batch * length, self.channels, self.image_size, self.image_size)
        feats = self.encoder_cnn(flat).mean(dim=(2, 3))  # global average pool over H, W
        return feats.view(batch, length, self.embed_dim)

    def forward(self, src, tgt):
        """Predict ``tgt`` from ``src`` with teacher forcing.

        Args:
            src: (batch, N, C, H, W) conditioning frames.
            tgt: (batch, K, C, H, W) ground-truth frames to predict.

        Returns:
            (batch, K, C, H, W) predictions; position i is the prediction of ``tgt[:, i]``.
        """
        batch = src.size(0)
        memory = self._embed_frames(src)
        for layer in self.encoder_layers:
            memory = layer(memory)

        # Shift right so the decoder never sees the frame it is asked to predict.
        dec_frames = torch.cat([src[:, -1:], tgt[:, :-1]], dim=1)
        steps = dec_frames.size(1)
        # Lower-triangular mask (1 = keep): position i attends only to positions <= i.
        causal_mask = torch.tril(torch.ones(steps, steps, device=src.device))

        output = self._embed_frames(dec_frames)
        for layer in self.decoder_layers:
            output = layer(output, memory, tgt_mask=causal_mask)

        latent = self.image_size // 4
        output = self.decoder_cnn_input(output)
        output = output.view(batch * steps, self.embed_dim, latent, latent)
        output = self.decoder_cnn(output)
        return output.view(batch, steps, self.channels, self.image_size, self.image_size)

# Perceptual Loss using VGG
class PerceptualLoss(nn.Module):
    """Pixel MSE plus MSE between frozen VGG16 feature maps.

    The feature extractor is the first 16 modules of torchvision's pretrained
    VGG16 ``features`` stack, frozen and in eval mode. Those weights were trained
    on ImageNet-normalised inputs, so both images are normalised with the ImageNet
    mean/std before the VGG pass. The pixel term uses the raw images.

    Inputs: (batch, 3, H, W) tensors. NOTE(review): assumed to be in [0, 1]
    (the model output is a Sigmoid) -- TODO confirm the dataset uses the same range.
    """

    # Per-channel statistics expected by torchvision's ImageNet-pretrained VGG16 weights.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self):
        super().__init__()
        vgg = vgg16(weights=VGG16_Weights.DEFAULT).features[:16].eval().to(device)
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg
        self.mse = nn.MSELoss()

    def _normalize(self, images):
        """Apply ImageNet per-channel normalisation, on the images' device and dtype."""
        mean = images.new_tensor(self._IMAGENET_MEAN).view(1, -1, 1, 1)
        std = images.new_tensor(self._IMAGENET_STD).view(1, -1, 1, 1)
        return (images - mean) / std

    def forward(self, pred, target):
        """Return MSE(VGG(norm(pred)), VGG(norm(target))) + MSE(pred, target)."""
        pred_vgg = self.vgg(self._normalize(pred))
        target_vgg = self.vgg(self._normalize(target))
        return self.mse(pred_vgg, target_vgg) + self.mse(pred, target)

# Training Function
def train_model(model, train_loader, criterion, optimizer, num_epochs):
    """Train ``model`` and return the mean training loss of every epoch.

    Each batch ``(src, tgt)`` is moved to the model's device, and ``model(src, tgt)``
    is compared with ``tgt``. The (batch, sequence) axes of both tensors are folded
    into one image axis, so ``criterion`` receives (batch * seq, C, H, W) tensors.

    Args:
        model: module called as ``model(src, tgt)`` returning a tensor shaped like ``tgt``.
        train_loader: iterable of ``(src, tgt)`` batches with a ``len()``.
        criterion: loss taking ``(prediction, target)`` image batches.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        List of per-epoch average losses (length ``num_epochs``).

    Raises:
        ValueError: if ``train_loader`` has no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    # Follow the model's own device rather than a module-level global.
    model_device = next(model.parameters()).device
    model.train()
    losses = []

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for src, tgt in train_loader:
            src, tgt = src.to(model_device), tgt.to(model_device)

            optimizer.zero_grad()
            output = model(src, tgt)
            # (batch, seq, C, H, W) -> (batch * seq, C, H, W), shape taken from the tensors.
            loss = criterion(output.flatten(0, 1), tgt.flatten(0, 1))
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / num_batches
        losses.append(avg_loss)
        print(f"\rEpoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}", end=" ")

    return losses

# Generate K images
def generate_images(model, input_seq, K, image_size, channels):
    """Autoregressively generate K images that continue ``input_seq``.

    The last input frame seeds the decoder. At each step:
    - the decoder layers run over all frames so far, under a causal
      (lower-triangular) mask;
    - the output at the final position is decoded into an image;
    - that image is appended to the decoder input for the next step.

    Args:
        model: module exposing ``encoder_cnn``, ``encoder_layers``,
            ``decoder_layers``, ``decoder_cnn_input`` and ``decoder_cnn``.
        input_seq: (batch, n_frames, channels, image_size, image_size) conditioning
            frames; ``n_frames`` is read from the tensor, not from a global.
        K: number of images to generate.
        image_size: height/width of each square frame.
        channels: channels per frame.

    Returns:
        Tensor of shape (batch, K, channels, image_size, image_size) on the model's device.

    Note:
        Switches ``model`` to eval mode and leaves it there.
    """
    model.eval()
    model_device = next(model.parameters()).device
    input_seq = input_seq.to(model_device)
    batch = input_seq.size(0)
    latent = image_size // 4  # spatial size fed to decoder_cnn

    if K <= 0:
        return input_seq.new_empty((batch, 0, channels, image_size, image_size))

    def embed(frames):
        # (batch, length, C, H, W) -> (batch, length, embed) via the CNN + global average pool.
        length = frames.size(1)
        flat = frames.reshape(batch * length, channels, image_size, image_size)
        return model.encoder_cnn(flat).mean(dim=(2, 3)).view(batch, length, -1)

    generated = []
    with torch.no_grad():
        memory = embed(input_seq)
        for layer in model.encoder_layers:
            memory = layer(memory)

        # Decoder-input embeddings; each frame is embedded once and then cached.
        tgt_emb = embed(input_seq[:, -1:])
        for _ in range(K):
            steps = tgt_emb.size(1)
            causal_mask = torch.tril(torch.ones(steps, steps, device=model_device))
            output = tgt_emb
            for layer in model.decoder_layers:
                output = layer(output, memory, tgt_mask=causal_mask)
            # Only the newest position becomes an image; earlier positions would be discarded.
            emb_dim = output.size(-1)
            last = model.decoder_cnn_input(output[:, -1])
            next_image = model.decoder_cnn(last.view(batch, emb_dim, latent, latent))
            next_image = next_image.view(batch, 1, channels, image_size, image_size)
            generated.append(next_image)
            tgt_emb = torch.cat([tgt_emb, embed(next_image)], dim=1)

    return torch.cat(generated, dim=1)

# Plot images
def plot_images(original, generated, N, K, image_size, channels):
    """Show the N input frames followed by the K generated frames in one row.

    Args:
        original: tensor whose first N entries are (C, H, W) input frames.
        generated: (1, K, C, H, W) or (K, C, H, W) tensor of generated frames.
        N: number of input frames to draw.
        K: number of generated frames to draw.
        image_size, channels: unused; kept for call-site compatibility.
    """
    original = original.detach().cpu().numpy()
    generated = generated.detach().squeeze(0).cpu().numpy()

    panels = [(original[i], f"Original {i+1}") for i in range(N)]
    panels += [(generated[i], f"Generated {i+1}") for i in range(K)]

    # A single figure: the previous version opened a second figure for the generated
    # frames, so each figure only filled half of its (1, N + K) subplot grid.
    plt.figure(figsize=(35, 10))
    for slot, (frame, title) in enumerate(panels, start=1):
        plt.subplot(1, N + K, slot)
        img = np.clip(np.transpose(frame, (1, 2, 0)), 0, 1)  # (C, H, W) -> (H, W, C)
        if img.shape[-1] == 1:
            # imshow rejects (H, W, 1); draw single-channel frames as grayscale.
            plt.imshow(img[..., 0], cmap="gray")
        else:
            plt.imshow(img)
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Main Execution
if __name__ == "__main__":
    # --- Model, loss and optimiser built from the module-level hyper-parameters ---
    model = EncoderTransformerDecoder(
        image_size, channels, embed_dim, num_heads,
        hidden_dim, num_layers, dropout, N, K,
    ).to(device)
    criterion = PerceptualLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # --- Training, then a plot of the per-epoch mean loss ---
    losses = train_model(model, train_loader, criterion, optimizer, num_epochs)
    plt.plot(range(1, len(losses) + 1), losses, label="Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training Loss Over Time")
    plt.legend()
    plt.show()

    # --- Qualitative check 1: continue one of the training windows ---
    test_idx = np.random.randint(0, len(train_data))
    test_input = train_data[test_idx:test_idx + 1]  # keeps a leading batch axis of size 1
    generated_images = generate_images(model, test_input, K, image_size, channels)
    plot_images(test_input.squeeze(0), generated_images, N, K, image_size, channels)

    # --- Qualitative check 2: N distinct paintings drawn at random from the dataset ---
    random_indices = np.random.choice(num_samples, size=N, replace=False)
    test_input = torch.tensor(data[random_indices], dtype=torch.float32)[None]
    generated_images = generate_images(model, test_input, K, image_size, channels)
    plot_images(test_input[0], generated_images, N, K, image_size, channels)

Epoch [32/10000], Loss: 4.1452 

In [None]:
"""
GENERATE A SEQUENCE OF K IMAGES FROM A GIVEN INPUT SEQUENCE OF N IMAGES
"""

# Generate K images
def generate_images(model, input_seq, K, image_size, channels):
    """Autoregressively generate K images that continue ``input_seq``.

    The last input frame seeds the decoder. At each step:
    - the decoder layers run over all frames so far, under a causal
      (lower-triangular) mask;
    - the output at the final position is decoded into an image;
    - that image is appended to the decoder input for the next step.

    Args:
        model: module exposing ``encoder_cnn``, ``encoder_layers``,
            ``decoder_layers``, ``decoder_cnn_input`` and ``decoder_cnn``.
        input_seq: (batch, n_frames, channels, image_size, image_size) conditioning
            frames; ``n_frames`` is read from the tensor, not from a global.
        K: number of images to generate.
        image_size: height/width of each square frame.
        channels: channels per frame.

    Returns:
        Tensor of shape (batch, K, channels, image_size, image_size) on the model's device.

    Note:
        Switches ``model`` to eval mode and leaves it there.
    """
    model.eval()
    model_device = next(model.parameters()).device
    input_seq = input_seq.to(model_device)
    batch = input_seq.size(0)
    latent = image_size // 4  # spatial size fed to decoder_cnn

    if K <= 0:
        return input_seq.new_empty((batch, 0, channels, image_size, image_size))

    def embed(frames):
        # (batch, length, C, H, W) -> (batch, length, embed) via the CNN + global average pool.
        length = frames.size(1)
        flat = frames.reshape(batch * length, channels, image_size, image_size)
        return model.encoder_cnn(flat).mean(dim=(2, 3)).view(batch, length, -1)

    generated = []
    with torch.no_grad():
        memory = embed(input_seq)
        for layer in model.encoder_layers:
            memory = layer(memory)

        # Decoder-input embeddings; each frame is embedded once and then cached.
        tgt_emb = embed(input_seq[:, -1:])
        for _ in range(K):
            steps = tgt_emb.size(1)
            causal_mask = torch.tril(torch.ones(steps, steps, device=model_device))
            output = tgt_emb
            for layer in model.decoder_layers:
                output = layer(output, memory, tgt_mask=causal_mask)
            # Only the newest position becomes an image; earlier positions would be discarded.
            emb_dim = output.size(-1)
            last = model.decoder_cnn_input(output[:, -1])
            next_image = model.decoder_cnn(last.view(batch, emb_dim, latent, latent))
            next_image = next_image.view(batch, 1, channels, image_size, image_size)
            generated.append(next_image)
            tgt_emb = torch.cat([tgt_emb, embed(next_image)], dim=1)

    return torch.cat(generated, dim=1)

# Plot images
def plot_images(original, generated, N, K, image_size, channels):
    """Show the N input frames followed by the K generated frames in one row.

    Args:
        original: tensor whose first N entries are (C, H, W) input frames.
        generated: (1, K, C, H, W) or (K, C, H, W) tensor of generated frames.
        N: number of input frames to draw.
        K: number of generated frames to draw.
        image_size, channels: unused; kept for call-site compatibility.
    """
    original = original.detach().cpu().numpy()
    generated = generated.detach().squeeze(0).cpu().numpy()

    panels = [(original[i], f"Original {i+1}") for i in range(N)]
    panels += [(generated[i], f"Generated {i+1}") for i in range(K)]

    # A single figure: the previous version opened a second figure for the generated
    # frames, so each figure only filled half of its (1, N + K) subplot grid.
    plt.figure(figsize=(35, 10))
    for slot, (frame, title) in enumerate(panels, start=1):
        plt.subplot(1, N + K, slot)
        img = np.clip(np.transpose(frame, (1, 2, 0)), 0, 1)  # (C, H, W) -> (H, W, C)
        if img.shape[-1] == 1:
            # imshow rejects (H, W, 1); draw single-channel frames as grayscale.
            plt.imshow(img[..., 0], cmap="gray")
        else:
            plt.imshow(img)
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.show()
    

# main
# Generate K images conditioned on N paintings drawn at random (without
# replacement) from the whole dataset, then plot inputs and outputs together.
# NOTE(review): the two assignments below rebind the module-level K and N, so any
# code that reads those globals afterwards sees 6 instead of the original values.

K=6 # number of images to generate
N=6 # number of input images the generation is conditioned on
random_indices = np.random.choice(num_samples, size=N, replace=False)
# data is (num_samples, channels, H, W) after the earlier transpose, so this is
# (1, N, channels, image_size, image_size): a batch holding one sequence.
test_input = torch.tensor(data[random_indices], dtype=torch.float32).unsqueeze(0)
generated_images = generate_images(model, test_input, K, image_size, channels)
plot_images(test_input[0], generated_images, N, K, image_size, channels)