In [1]:
import torch, torchvision
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.cuda.amp import autocast, GradScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import torchvision.utils as vutils
# Build a torch.device in both branches so the value has the same type
# (and e.g. `device.type`) whether or not CUDA is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class OurAttentionLayer(nn.Module):
    """Single-head attention over the channel axis of patched images.

    The attention output is projected back to patch size, added to the input
    (residual) and passed through LayerNorm.

    In self-attention mode (``cross=False``) queries, keys and values are all
    projected from ``image``. In cross-attention mode (``cross=True``) keys and
    values come from ``image`` while the queries are built from ``text``.

    Shape notation:
        C  -- channels_in
        pq -- patch_size (pixels per patch)
        N  -- hidden_dim
        Z  -- emb_size (text embedding size)
        S  -- text sequence length

    Args:
        patch_size: size ``pq`` of the last axis of ``image``.
        channels_in: number of channels ``C`` (second-to-last axis of ``image``).
        hidden_dim: width ``N`` of the query/key/value projections.
        emb_size: size ``Z`` of the last axis of ``text`` (cross mode only).
        cross: if True, queries are computed from ``text``.
    """

    def __init__(self, patch_size, channels_in, hidden_dim, emb_size=1, cross=False):
        super().__init__()
        self.cross = cross
        self.hidden_dim = hidden_dim    # N
        self.channels_in = channels_in  # C
        self.patch_size = patch_size    # pq
        self.emb_size = emb_size        # Z
        self.Wk = nn.Linear(patch_size, hidden_dim)  # [pq, N]
        self.Wv = nn.Linear(patch_size, hidden_dim)  # [pq, N]
        self.LN = nn.LayerNorm([channels_in, patch_size])
        if cross:
            self.Wi = nn.Linear(emb_size, channels_in)   # [Z, C]
            self.Wj = nn.Linear(emb_size, hidden_dim)    # [Z, N]
            self.Wq = nn.Linear(hidden_dim, hidden_dim)  # [N, N]
        else:
            self.Wq = nn.Linear(patch_size, hidden_dim)  # [pq, N]
        self.Wr = nn.Linear(hidden_dim, patch_size)      # [N, pq]
        self.softmax = nn.Softmax(dim=-1)
        self.dscale = 1 / (hidden_dim ** 0.5)  # 1/sqrt(d_k) with d_k = N

    def forward(self, image, text=None, ret_attn_QKV=False):
        """Attend, project back to patch size, add the residual and normalise.

        Args:
            image: tensor of shape ``[..., C, pq]``.
            text: tensor of shape ``[..., S, Z]`` (cross mode only). When omitted
                in cross mode, a random ``[1, Z]`` placeholder is drawn.
                Ignored when ``cross=False``.
            ret_attn_QKV: if True, also return the ``Q``, ``K`` and ``V`` tensors.

        Returns:
            ``O`` with the same shape as ``image``, or ``(O, Q, K, V)`` when
            ``ret_attn_QKV`` is True.
        """
        K = self.Wk(image)  # [..., C, pq] @ [pq, N] -> [..., C, N]
        V = self.Wv(image)  # [..., C, N]
        if self.cross:
            if text is None:
                # Placeholder text, allocated on the layer's own device/dtype so
                # it keeps working after the module is moved (e.g. to CUDA).
                text = torch.rand(
                    1, self.emb_size,
                    device=self.Wi.weight.device, dtype=self.Wi.weight.dtype,
                )
            I = self.Wi(text)  # [..., S, Z] @ [Z, C] -> [..., S, C]
            J = self.Wj(text)  # [..., S, Z] @ [Z, N] -> [..., S, N]
            # I^T @ J, contracting over the text sequence axis S.
            Q1 = torch.einsum("...sc,...sn->...cn", I, J)  # [..., C, N]
            Q = self.Wq(Q1)
            # Insert singleton axes right before [C, N] until Q has K's rank, so
            # the text batch axis lines up with the image batch axis and the
            # remaining (patch) axes of the image broadcast.
            while Q.dim() < K.dim():
                Q = Q.unsqueeze(-3)
            Q = Q.expand_as(K)
        else:
            Q = self.Wq(image)  # [..., C, pq] @ [pq, N] -> [..., C, N]

        # Scaled dot-product attention: softmax(Q K^T / sqrt(N)) V.
        # Rows index queries i, the softmax runs over keys j, and the weights
        # are applied to the values that share the key index j.
        scores = torch.einsum("...in,...jn->...ij", Q, K) * self.dscale
        attn = self.softmax(scores)
        R = torch.einsum("...ij,...jn->...in", attn, V)  # [..., C, N]
        O = self.Wr(R)  # [..., C, N] @ [N, pq] -> [..., C, pq]
        O = O + image   # residual connection
        O = self.LN(O)
        if ret_attn_QKV:
            return O, Q, K, V
        return O

class PatchImage(nn.Module):
    """Split images into flattened square patches, or reassemble them.

    Forward mode maps ``[b, c, H, W]`` to ``[b, c, H // n, W // n, n * n]``,
    where ``n * n == patch_size``. The last axis holds one ``n x n`` patch
    flattened in row-major order. With ``reverse=True`` the inverse mapping
    ``[b, c, h, w, n * n] -> [b, c, h * n, w * n]`` is applied.

    Args:
        patch_size: number of pixels per patch; must be a positive perfect square.
        reverse: if True, reassemble patches instead of extracting them.

    Raises:
        ValueError: if ``patch_size`` is not a positive perfect square, or (in
            ``forward``) if the input dimensions do not fit the patch size.
    """

    def __init__(self, patch_size, reverse=False):
        super().__init__()
        if patch_size <= 0:
            raise ValueError(
                f"patch_size must be a positive perfect square, got {patch_size}"
            )
        self.patch_size = patch_size
        # round() rather than int() so a float sqrt that lands just below an
        # integer is not truncated.
        self.n = int(round(patch_size ** 0.5))
        if self.n ** 2 != patch_size:
            raise ValueError(
                f"patch_size must be a positive perfect square, got {patch_size}"
            )
        self.reverse = reverse

    def forward(self, x):
        n = self.n
        if self.reverse:
            b, c, h, w, s = x.shape
            if s != n * n:
                raise ValueError(f"last axis must have size {n * n}, got {s}")
            x = x.reshape(b, c, h, w, n, n)
            x = x.transpose(-2, -3)  # [b, c, h, n, w, n]
            return x.reshape(b, c, h * n, w * n)
        b, c, h, w = x.shape
        if h % n or w % n:
            raise ValueError(
                f"image size ({h}, {w}) is not divisible by patch side {n}"
            )
        x = x.reshape(b, c, h // n, n, w // n, n)
        x = x.transpose(-2, -3)  # [b, c, h // n, w // n, n, n]
        return x.reshape(b, c, h // n, w // n, n * n)