In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageEncoder(nn.Module):
    """Small convolutional encoder mapping an RGB image batch to feature vectors.

    Three stride-2 3x3 convolutions (each followed by ReLU) are applied, the
    resulting map is globally average-pooled to 1x1, and a linear layer
    projects the 128 pooled channels to ``feature_dim``.

    Args:
        feature_dim: Size of the output feature vector per image.
    """

    def __init__(self, feature_dim=256):
        super().__init__()
        widths = (3, 32, 64, 128)
        stages = []
        # Modules are appended in conv/ReLU pairs so their positions inside
        # the Sequential (0, 2, 4 for the convolutions) stay fixed.
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1))
            stages.append(nn.ReLU())
        stages.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.cnn = nn.Sequential(*stages)
        self.fc = nn.Linear(widths[-1], feature_dim)

    def forward(self, x):
        """Encode ``x`` and return a ``(batch, feature_dim)`` tensor."""
        pooled = self.cnn(x)
        # Collapse the trailing 1x1 spatial dims left by the adaptive pooling.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

class SeqDecoder(nn.Module):
    """GRU decoder that unrolls a fixed number of steps and emits token logits.

    The vocabulary holds ``num_actions + 2`` entries (two extra slots are
    reserved for start and end tokens); index 0 is fed as the start token.

    Args:
        feature_dim: Size of the conditioning vector; also the GRU hidden size.
        num_actions: Number of action tokens.
        seq_len: Number of decoding steps performed by ``forward``.
        emb_dim: Dimensionality of the token embeddings.
    """

    def __init__(self, feature_dim, num_actions, seq_len=10, emb_dim=128):
        super().__init__()
        self.seq_len = seq_len
        self.action_emb = nn.Embedding(num_actions + 2, emb_dim) # +2 for start and end tokens
        self.gru = nn.GRU(emb_dim, feature_dim, batch_first=True)
        self.fc = nn.Linear(feature_dim, num_actions + 2)

    def forward(self, features, action_seq=None, teacher_forcing_ratio=0.5):
        """Decode ``seq_len`` steps starting from ``features`` as the hidden state.

        Args:
            features: ``(B, feature_dim)`` tensor used as the initial GRU hidden
                state.
            action_seq: Optional ``(B, >= seq_len)`` integer tensor of target
                tokens used for teacher forcing. Entries outside
                ``[0, num_actions + 2)`` (e.g. a loss ``ignore_index`` such as
                -100 used for padding) are never embedded; the model's own
                greedy prediction is fed at those positions instead.
            teacher_forcing_ratio: Probability, drawn once per time step for the
                whole batch, of feeding the target token rather than the
                model's greedy prediction.

        Returns:
            ``(B, seq_len, num_actions + 2)`` tensor of logits.
        """
        batch_size = features.size(0)
        vocab_size = self.action_emb.num_embeddings
        input_token = torch.full((batch_size, 1), 0, dtype=torch.long, device=features.device) # start token idx = 0
        outputs = []
        hidden = features.unsqueeze(0) # (1, B, F)
        for t in range(self.seq_len):
            emb = self.action_emb(input_token) # (B, 1, emb_dim)
            out, hidden = self.gru(emb, hidden)
            logits = self.fc(out.squeeze(1)) # (B, num_actions+2)
            outputs.append(logits)
            predicted = logits.argmax(dim=-1, keepdim=True) # (B, 1)
            if action_seq is not None and torch.rand(1).item() < teacher_forcing_ratio:
                teacher = action_seq[:, t].unsqueeze(1).long()
                # Padding / ignore markers are not valid embedding indices and
                # would raise inside nn.Embedding; substitute the prediction.
                valid = (teacher >= 0) & (teacher < vocab_size)
                input_token = torch.where(valid, teacher, predicted)
            else:
                input_token = predicted
        return torch.stack(outputs, dim=1) # (B, seq_len, num_actions+2)

class RotSeqModel(nn.Module):
    """Predicts a token sequence from a (start image, end image) pair.

    Both images go through the same ``ImageEncoder``; the two feature vectors
    are concatenated and handed to a ``SeqDecoder`` whose hidden size is
    therefore ``2 * feature_dim``.

    Args:
        num_actions: Number of action tokens passed to the decoder.
        seq_len: Number of decoding steps.
        feature_dim: Per-image feature size produced by the encoder.
    """

    def __init__(self, num_actions=4, seq_len=10, feature_dim=256):
        super().__init__()
        self.num_actions = num_actions
        self.seq_len = seq_len
        # Encoder is built before the decoder so parameter initialisation
        # order is preserved.
        self.encoder = ImageEncoder(feature_dim)
        self.decoder = SeqDecoder(2 * feature_dim, num_actions, seq_len)

    def forward(self, img_start, img_end, action_seq=None, teacher_forcing_ratio=0.5):
        """Return decoder logits for the given image pair.

        ``action_seq`` and ``teacher_forcing_ratio`` are forwarded unchanged to
        the decoder.
        """
        # One shared encoder, applied to the start image first and then the end image.
        encoded = [self.encoder(image) for image in (img_start, img_end)]
        pair_features = torch.cat(encoded, dim=-1)
        return self.decoder(
            pair_features,
            action_seq=action_seq,
            teacher_forcing_ratio=teacher_forcing_ratio,
        )


Predicting the sequence of transformations that turns one image into another.

In [2]:
import torch.nn.functional as F

def sequence_loss(pred_logits, target_seq, ignore_index=-100):
    """Token-level cross-entropy averaged over all non-ignored positions.

    Args:
        pred_logits: ``(B, seq_len, num_tokens)`` tensor of unnormalised scores.
        target_seq: ``(B, seq_len)`` integer tensor of target token ids.
        ignore_index: Target value whose positions do not contribute to the
            loss.

    Returns:
        Scalar tensor with the mean cross-entropy.
    """
    num_tokens = pred_logits.size(-1)
    # reshape (not view) so non-contiguous inputs, e.g. a shifted slice such as
    # tokens[:, 1:] or transposed logits, are flattened instead of raising.
    flat_logits = pred_logits.reshape(-1, num_tokens)
    flat_targets = target_seq.reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets, ignore_index=ignore_index)


In [None]:
# Example of a single forward/backward pass.
# Expected inputs:
#   img_start, img_end: (B, 3, H, W) tensors
#   target_seq: (B, seq_len) LongTensor with tokens
# NOTE(review): img_start, img_end and target_seq are not defined in the
# visible cells; they must exist before this cell runs, otherwise it raises
# NameError on a fresh kernel.
# seq_len=6 here, so target_seq presumably has 6 columns to match the logits
# in the loss below - confirm.
model = RotSeqModel(num_actions=4, seq_len=6)
# teacher_forcing_ratio is left at the forward() default of 0.5.
logits = model(img_start, img_end, action_seq=target_seq)
loss = sequence_loss(logits, target_seq)
# Backpropagate only; no optimizer step is taken in this cell.
loss.backward()