<a href="https://colab.research.google.com/github/KeneKing12/Kenechukwu/blob/main/DETR_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

class PositionalEncoding2D(nn.Module):
    """Fixed (non-learned) 2D sinusoidal positional encoding for feature maps.

    The first ``d_model // 2`` channels encode the column (width) index and the
    remaining ``d_model // 2`` channels encode the row (height) index. Within
    each half, even channels hold ``sin`` terms and odd channels hold ``cos``
    terms. The encoding is added to the input, so ``forward`` expects a 4-D
    tensor whose channel dimension (dim 1) matches ``d_model``.

    Args:
        d_model: Number of channels; must be divisible by 4.
        height: Number of rows precomputed in the ``pe`` buffer.
        width: Number of columns precomputed in the ``pe`` buffer.

    Raises:
        ValueError: If ``d_model`` is not divisible by 4.
    """

    def __init__(self, d_model, height, width):
        super().__init__()
        if d_model % 4 != 0:
            raise ValueError("d_model must be divisible by 4")

        self.height = height
        self.width = width
        self.d_model = d_model

        # Shape [1, d_model, height, width]. Registered as a buffer so it follows
        # .to(device) and is stored in the state_dict, but is not trained.
        self.register_buffer('pe', self._build_pe(d_model, height, width))

    @staticmethod
    def _build_pe(d_model, height, width):
        """Return the sinusoidal encoding as a [1, d_model, height, width] tensor."""
        pe = torch.zeros(d_model, height, width)
        half = d_model // 2
        # Frequencies 10000^(-k/half) for k = 0, 2, ..., half - 2, shared by both axes.
        div_term = torch.exp(torch.arange(0., half, 2) * -(torch.log(torch.tensor(10000.0)) / half))
        pos_w = torch.arange(0., width).unsqueeze(1)   # [width, 1]
        pos_h = torch.arange(0., height).unsqueeze(1)  # [height, 1]

        # Width encoding: varies along columns, repeated down every row.
        pe[0:half:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[1:half:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        # Height encoding: varies along rows, repeated across every column.
        pe[half::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        pe[half + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

        return pe.unsqueeze(0)

    def forward(self, x):
        """Add the positional encoding to ``x`` (indexed as [B, C, H, W]).

        Inputs no larger than the precomputed ``(height, width)`` reuse a slice
        of the ``pe`` buffer. Larger inputs get an encoding computed on the fly
        instead of failing with a shape mismatch. The encoding depends only on
        absolute row/column indices, so its top-left ``(height, width)`` corner
        is identical to the buffer.
        """
        h, w = x.size(2), x.size(3)
        if h <= self.height and w <= self.width:
            pe = self.pe[:, :, :h, :w]
        else:
            pe = self._build_pe(self.d_model, h, w).to(device=x.device, dtype=self.pe.dtype)
        return x + pe

class DETROnlyModel(nn.Module):
    """Transformer-encoder model mapping [B, in_channels, H, W] to [B, 1, H, W].

    Pipeline: a 1x1 conv projects the input to ``d_model`` channels, a 2D
    positional encoding is added, and every pixel becomes a token for a
    ``num_layers``-deep ``nn.TransformerEncoder``. A 1x1 conv then reduces the
    result to one channel, and a sigmoid squashes it into (0, 1).

    ``height``/``width`` size the positional-encoding table built by
    ``PositionalEncoding2D``.
    """

    def __init__(self, in_channels=11, d_model=256, nhead=8, num_layers=6, ff_dim=2048, height=16, width=16):
        super().__init__()
        # Per-pixel linear lift from in_channels to d_model features.
        self.input_proj = nn.Conv2d(in_channels, d_model, kernel_size=1)
        self.pos_enc = PositionalEncoding2D(d_model, height, width)

        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff_dim,
            batch_first=True,  # tokens are laid out as [batch, sequence, features]
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)

        # Per-pixel reduction from d_model features to a single output channel.
        self.output_proj = nn.Conv2d(d_model, 1, kernel_size=1)

    def forward(self, x):
        """Return sigmoid outputs of shape [B, 1, H, W] for input ``x``."""
        feats = self.pos_enc(self.input_proj(x))
        batch, channels, rows, cols = feats.shape

        # Every pixel is one token: [B, C, H, W] -> [B, H*W, C].
        tokens = feats.reshape(batch, channels, rows * cols).transpose(1, 2)
        tokens = self.transformer(tokens)

        # Fold tokens back into a feature map: [B, H*W, C] -> [B, C, H, W].
        feats = tokens.transpose(1, 2).reshape(batch, channels, rows, cols)
        return self.output_proj(feats).sigmoid()

# Accuracy Function
def compute_accuracy(preds, masks, threshold=0.5):
    """Return the fraction of elements where thresholded ``preds`` equal ``masks``.

    Args:
        preds: Tensor of scores; each element is binarized as ``preds > threshold``.
        masks: Target tensor compared element-wise against the binarized preds
            (0/1 floats or bools compare as expected).
        threshold: Cut-off applied to ``preds``.

    Returns:
        A 0-dim float tensor in [0, 1]. It is NaN when the inputs are empty,
        since the count of compared elements is zero.

    Note:
        If the shapes differ but the element counts match, ``masks`` is
        reshaped to the prediction shape before comparing. This handles the
        common case of ``[B, 1, H, W]`` predictions against ``[B, H, W]``
        masks, which would otherwise broadcast to ``[B, B, H, W]`` and give a
        wrong accuracy. All other shape combinations fall back to ordinary
        broadcasting, as before.
    """
    binary = (preds > threshold).float()
    if (
        torch.is_tensor(masks)
        and binary.shape != masks.shape
        and binary.numel() == masks.numel()
    ):
        masks = masks.reshape(binary.shape)
    correct = (binary == masks).float()
    return correct.sum() / correct.numel()

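# Example usage (assumes `device` is defined, e.g.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
# model = DETROnlyModel().to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# criterion = nn.BCELoss()  # the model already applies sigmoid, so BCELoss (not BCEWithLogitsLoss) is the matching loss
# ... training loop similar to the one in the previous notebook cell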
