# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

class Backbone(nn.Module):
    """Placeholder for the image feature extractor.

    No feature-extraction layers are defined yet. The class is a properly
    initialised ``nn.Module`` so it can be instantiated, moved between
    devices, and queried for parameters, but calling it raises
    ``NotImplementedError`` until a real forward pass is written.
    """

    def __init__(self):
        # nn.Module bookkeeping (_parameters, _modules, ...) must exist before
        # any attribute access, parameters() or state_dict() call.
        super().__init__()

    def forward(self, *args, **kwargs):
        """Raise ``NotImplementedError``; the backbone has no layers yet.

        The previous body referenced ``bs`` and ``self.query_embed``, neither
        of which exists on this class, so it could never run.
        """
        raise NotImplementedError(
            "Backbone.forward is not implemented: define the feature "
            "extraction layers in __init__ and apply them here."
        )

    # Backward-compatible alias for the original (misspelled) method name.
    foward = forward

class Transformer(nn.Module):
    """Encoder-decoder stack built from ``torch.nn`` transformer layers.

    Args:
        num_layers: Number of encoder layers. Also the number of decoder
            layers unless ``num_decoder_layers`` is given.
        d_model: Embedding width of every token.
        d_ff: Hidden width of the feed-forward sub-layers.
        dropout: Dropout probability used inside every layer.
        nhead: Number of attention heads for both stacks. ``None`` keeps the
            historical defaults (1 head in the encoder, 2 in the decoder).
        num_decoder_layers: Number of decoder layers; ``None`` means
            ``num_layers``.
    """

    def __init__(self, num_layers=6, d_model=256, d_ff=2048, dropout=0.1,
                 nhead=None, num_decoder_layers=None):
        super().__init__()
        encoder_heads = 1 if nhead is None else nhead
        decoder_heads = 2 if nhead is None else nhead
        if num_decoder_layers is None:
            num_decoder_layers = num_layers

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=encoder_heads,
                dim_feedforward=d_ff,
                dropout=dropout,
            ),
            num_layers=num_layers,
        )
        # The decoder stack needs decoder layers (self-attention plus
        # cross-attention over the encoder memory). An encoder layer does not
        # accept the ``memory`` argument that nn.TransformerDecoder passes.
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=decoder_heads,
                dim_feedforward=d_ff,
                dropout=dropout,
            ),
            num_layers=num_decoder_layers,
        )

    def forward(self, src, query, return_memory=False):
        """Encode ``src`` and decode ``query`` against the encoded memory.

        Both inputs are sequence-first, ``(length, batch, d_model)``, because
        the layers are built with the default ``batch_first=False``.

        Args:
            src: Encoder input sequence.
            query: Decoder input sequence.
            return_memory: When true, also return the encoder output.

        Returns:
            The decoder output, or ``(decoder_output, memory)`` when
            ``return_memory`` is true.
        """
        memory = self.encoder(src)
        output = self.decoder(query, memory)
        if return_memory:
            return output, memory
        return output


class PanopticHead(nn.Module):
    """Linear classifier plus a small convolutional mask predictor.

    Args:
        feature_dim: Width of the incoming feature vectors / channels of the
            incoming feature map.
        num_classes: Number of classification logits.
        mask_dim: Number of output channels of the mask predictor.
    """

    def __init__(self, feature_dim, num_classes, mask_dim):
        super().__init__()
        self.classifier = nn.Linear(feature_dim, num_classes)
        self.mask_predictor = nn.Sequential(
            nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(feature_dim, mask_dim, kernel_size=1),
        )

    def forward(self, features, feature_map=None):
        """Return ``(classes, masks)``.

        Args:
            features: Tensor whose last dimension is ``feature_dim``; fed to
                the linear classifier.
            feature_map: Optional ``(batch, feature_dim, H, W)`` tensor for
                the mask predictor. When omitted, ``features`` is treated as
                a stack of 1x1 maps (the original behaviour), which only
                works for inputs the convolution accepts after two trailing
                singleton dimensions are appended, e.g. ``(N, feature_dim)``.

        Returns:
            ``classes`` with last dimension ``num_classes`` and ``masks`` with
            ``mask_dim`` channels and the spatial size of the map used.
        """
        classes = self.classifier(features)
        if feature_map is None:
            feature_map = features.unsqueeze(-1).unsqueeze(-1)
        masks = self.mask_predictor(feature_map)
        return classes, masks


class DETR(nn.Module):
    """DETR-style model: transformer over image features plus panoptic head.

    Args:
        num_classes: Number of classification logits per query.
        num_queries: Number of learned object queries.
        hidden_dim: Transformer width; input feature maps must have this many
            channels.
        num_heads: Attention heads in every transformer layer.
        num_encoder_layers: Depth of the encoder stack.
        num_decoder_layers: Depth of the decoder stack.
    """

    def __init__(self, num_classes, num_queries, hidden_dim=256, num_heads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()
        # Keyword arguments keep each value attached to the right parameter.
        # Passed positionally, num_heads landed in d_ff and hidden_dim * 4 in
        # dropout, which nn.Dropout rejects.
        self.transformer = Transformer(
            num_layers=num_encoder_layers,
            d_model=hidden_dim,
            d_ff=hidden_dim * 4,
            nhead=num_heads,
            num_decoder_layers=num_decoder_layers,
        )
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.panoptic_head = PanopticHead(hidden_dim, num_classes, mask_dim=1)

    def forward(self, src, mask):
        """Predict per-query class logits and one mask per image.

        Args:
            src: Feature map of shape ``(batch_size, channel, height, width)``
                with ``channel == hidden_dim``.
            mask: Accepted for interface compatibility; not used.
                NOTE(review): presumably a padding mask -- confirm with the
                caller before wiring it into the attention layers.

        Returns:
            ``class_logits`` of shape ``(batch_size, num_queries,
            num_classes)`` and ``masks`` of shape ``(batch_size, height,
            width)``.

        Raises:
            ValueError: If ``src`` does not have ``hidden_dim`` channels.
        """
        bs, c, h, w = src.size()
        hidden_dim = self.query_embed.embedding_dim
        if c != hidden_dim:
            raise ValueError(
                f"src has {c} channels but the model expects hidden_dim={hidden_dim}"
            )

        tokens = src.flatten(2).permute(2, 0, 1)  # (h*w, bs, c)
        # Learned object queries, replicated for every image in the batch.
        queries = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)  # (nq, bs, c)

        hs, memory = self.transformer(tokens, queries, return_memory=True)

        # nn.TransformerDecoder returns only the final layer, shaped
        # (nq, bs, c); there is no leading layer axis to index with [-1].
        query_features = hs.transpose(0, 1)  # (bs, nq, c)
        # Undo the flatten/permute so the conv mask predictor sees a real
        # spatial map instead of a 1x1 one.
        feature_map = memory.permute(1, 2, 0).reshape(bs, c, h, w)

        class_logits, masks = self.panoptic_head(query_features, feature_map)
        # mask_dim is 1, so (bs, 1, h, w) collapses to (bs, h, w).
        return class_logits, masks.view(bs, h, w)