In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn as nn
import torch.nn.functional as F


class WindowClassifierWithTransformer(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=64, num_classes=7,
                 num_layers=2, nhead=8, dropout=0.1, num_windows=9):
        super(WindowClassifierWithTransformer, self).__init__()

        self.embedding = nn.Linear(input_dim, hidden_dim)

        self.positional_encoding = nn.Parameter(torch.randn(1, num_windows, hidden_dim))

        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.embedding(x)

        x = x + self.positional_encoding

        x = x.transpose(0, 1)

        x = self.transformer_encoder(x)

        x = x.transpose(0, 1)

        logits = self.classifier(x)

        probs = F.softmax(logits, dim=-1)
        return probs


# Example usage:
if __name__ == "__main__":
    # Create a sample input: batch_size=2, num_windows=9, input_dim=3
    sample_input = torch.randn(2, 9, 3)

    # Instantiate the classifier
    model = WindowClassifierWithTransformer()

    # Forward pass through the model
    output = model(sample_input)
    print(output.shape)  # Expected shape: (2, 9, 7)

torch.Size([2, 9, 7])


In [8]:
x = torch.randn(32, 9, 3)  # Example input

model = WindowClassifierWithTransformer()
output = model(x)

In [9]:
output.shape

torch.Size([32, 9, 7])

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Cross-Stitch Unit
class CrossStitchUnit(nn.Module):
    def __init__(self):
        super(CrossStitchUnit, self).__init__()
        # Learnable alpha parameters (initialized to identity)
        self.alpha = nn.Parameter(torch.tensor([[0.9, 0.1], [0.1, 0.9]], requires_grad=True))

    def forward(self, a, b):
        # a and b are features from each task branch
        a_out = self.alpha[0, 0] * a + self.alpha[0, 1] * b
        b_out = self.alpha[1, 0] * a + self.alpha[1, 1] * b
        return a_out, b_out

# A simple convolutional block
class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        return self.pool(F.relu(self.conv(x)))

# Main Multi-Task Network with Cross-Stitch
class CrossStitchNet(nn.Module):
    def __init__(self):
        super(CrossStitchNet, self).__init__()
        # Task-specific initial blocks
        self.taskA_conv1 = ConvBlock()
        self.taskB_conv1 = ConvBlock()

        # Cross-stitch unit after first conv layer
        self.cross_stitch = CrossStitchUnit()

        # Shared second conv block
        self.taskA_conv2 = ConvBlock()
        self.taskB_conv2 = ConvBlock()

        # Task-specific heads
        self.taskA_fc = nn.Linear(16 * 7 * 7, 10)  # For classification
        self.taskB_fc = nn.Linear(16 * 7 * 7, 1)   # For regression

    def forward(self, x):
        a = self.taskA_conv1(x)
        b = self.taskB_conv1(x)

        # Cross-stitch blending
        a, b = self.cross_stitch(a, b)

        # Continue task-specific paths
        a = self.taskA_conv2(a)
        b = self.taskB_conv2(b)

        # Flatten
        a = a.view(a.size(0), -1)
        b = b.view(b.size(0), -1)

        # Final heads
        outA = self.taskA_fc(a)
        outB = self.taskB_fc(b)
        return outA, outB

In [None]:
x = torch.randn(32, 1, 28, 28)  # Example input

model = CrossStitchNet()
output = model(x)

In [12]:
output.shape

torch.Size([32, 9, 7])