<a href="https://colab.research.google.com/github/aliborji/A-Collection-of-important-tasks-in-pytorch/blob/master/mnist_seq2seq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from multiprocessing import cpu_count
import matplotlib.pyplot as plt
%matplotlib inline
ROOT = './'  # directory passed to torchvision's MNIST (download=True) in get_data

In [None]:
def get_data(batch_size=64, valid_pct=0.3, num_workers=None):
    """Build train/validation DataLoaders from the MNIST training split.

    The images are scaled to [0, 1] floats. The split is positional (no shuffle
    before splitting): the first ``1 - valid_pct`` fraction is used for training
    and the rest for validation.

    Args:
        batch_size: training batch size; validation uses twice this size.
        valid_pct: fraction of samples held out for validation, in [0, 1).
        num_workers: DataLoader worker processes; defaults to ``cpu_count()``.

    Returns:
        ``(train_dl, valid_dl)``; each batch is an ``(images, labels)`` tuple.

    Raises:
        ValueError: if ``valid_pct`` is outside [0, 1).
    """
    # TensorDataset is a module-level (picklable) class. A dataset class defined
    # inside this function cannot be pickled, which breaks worker processes
    # started with the 'spawn' method (default on macOS / Windows).
    from torch.utils.data import TensorDataset

    if not 0 <= valid_pct < 1:
        raise ValueError(f'valid_pct must be in [0, 1), got {valid_pct}')

    def split(data):
        # Positional split along the first (sample) dimension.
        cutoff = int((1 - valid_pct) * data.shape[0])
        return data[:cutoff], data[cutoff:]

    mnist = MNIST(ROOT, download=True)
    train_images, valid_images = split(mnist.data.float() / 255)
    train_labels, valid_labels = split(mnist.targets)
    train_ds = TensorDataset(train_images, train_labels)
    valid_ds = TensorDataset(valid_images, valid_labels)

    if num_workers is None:
        num_workers = cpu_count()
    # Pinned host memory only helps (and only avoids a warning) with a CUDA device.
    loader_kwargs = dict(num_workers=num_workers,
                         pin_memory=torch.cuda.is_available())

    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, **loader_kwargs)
    valid_dl = DataLoader(valid_ds, batch_size=batch_size * 2, shuffle=False, **loader_kwargs)
    return train_dl, valid_dl

In [None]:
class Seq2Seq(nn.Module):
    """Encoder-decoder transformer returning logits over the target vocabulary.

    The output projection reuses (ties) the decoder's token-embedding matrix.
    """

    def __init__(self,
                 dim,            # dimension of attention layers
                 src_vocab_size, # passed to Encoder (unused there: source is already vectorised)
                 tgt_vocab_size, # size of the target token vocabulary
                 src_max_len,    # used to determine positional embeddings
                 tgt_max_len,    # ...
                 device,         # the device hosting the computation
                 num_heads=1,    # number of heads per attention layer (divides into dim)
                 num_layers=1,   # number of encoder & decoder layers
                 expand=1,       # the factor by which to expand dim for internal attention layer
                 p=0.1,          # dropout probability
                 pad_idx=None):  # source padding token id; None disables source masking
        super().__init__()
        self.encoder = Encoder(dim, src_vocab_size, src_max_len, num_heads, num_layers, expand, p)
        self.decoder = Decoder(dim, tgt_vocab_size, tgt_max_len, num_heads, num_layers, expand, p)
        self.device = device
        self.pad_idx = pad_idx

    def __len__(self):
        """Return the number of trainable parameters (not a sequence length)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def make_src_mask(self, src):
        """Return a (N, 1, 1, len) padding mask for token-id ``src`` of shape (N, len).

        Returns None when no ``pad_idx`` is configured. Continuous inputs, such as
        the image rows used in this notebook, have no padding.
        """
        # getattr keeps instances created before pad_idx existed working.
        pad_idx = getattr(self, 'pad_idx', None)
        if pad_idx is None:
            return None
        return (src != pad_idx).unsqueeze(1).unsqueeze(2).to(src.device)

    def make_tgt_mask(self, tgt):
        """Return a (len, len) lower-triangular bool mask for ``tgt`` of shape (N, len).

        Entry [i, j] is True when query position i may attend to key position j
        (j <= i). It broadcasts over batch and heads inside MultiHeadAttention.
        """
        tgt_len = tgt.shape[1]
        ones = torch.ones(tgt_len, tgt_len, device=tgt.device)
        return torch.tril(ones).bool()

    def forward(self, src, tgt):
        """Encode ``src`` and decode ``tgt`` (teacher forcing); return (N, tgt_len, vocab) logits."""
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)

        enc_out = self.encoder(src, src_mask)
        dec_out = self.decoder(tgt, enc_out, src_mask, tgt_mask)
        # Weight tying: project onto the target token embeddings.
        return F.linear(dec_out, self.decoder.embedding.tok_embedding.weight)



class Encoder(nn.Module):
    """Positional embedding followed by a stack of self-attention TransformerBlocks.

    ``vocab_size`` is accepted for symmetry with Decoder but is not used: the
    source goes through PosEmbedding, which does no token lookup.
    """

    def __init__(self, dim, vocab_size, max_len, num_heads, num_layers, expand, p):
        super().__init__()
        self.embedding = PosEmbedding(dim, max_len, p)
        blocks = [TransformerBlock(dim, num_heads, expand, p) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, src, mask=None):
        """Map ``src`` (N, seq_len, dim) to encoded states (N, seq_len, dim)."""
        hidden = self.embedding(src)
        for block in self.layers:
            # Self-attention: query, key and value are all the running state.
            hidden = block(hidden, hidden, hidden, mask)
        return hidden



class Decoder(nn.Module):
    """Token + positional embedding followed by a stack of DecoderBlocks."""

    def __init__(self, dim, vocab_size, max_len, num_heads, num_layers, expand, p):
        super().__init__()
        self.embedding = Embedding(dim, vocab_size, max_len, p)
        blocks = [DecoderBlock(dim, num_heads, expand, p) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, tgt, enc_out, src_mask, tgt_mask):
        """Decode token ids ``tgt`` (N, len) against ``enc_out``; return (N, len, dim)."""
        state = self.embedding(tgt)
        for block in self.layers:
            # Encoder output serves as both keys and values for cross-attention.
            state = block(state, enc_out, enc_out, src_mask, tgt_mask)
        return state



class DecoderBlock(nn.Module):
    """Masked self-attention, then a TransformerBlock that cross-attends to the encoder."""

    def __init__(self, dim, num_heads, expand, p):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attention = MultiHeadAttention(dim, num_heads, p)
        self.transformer_block = TransformerBlock(dim, num_heads, expand, p)
        self.dropout = nn.Dropout(p)

    def forward(self, x, key, value, src_mask=None, tgt_mask=None):
        """Self-attend over ``x`` under ``tgt_mask``, then cross-attend to ``key``/``value``."""
        self_attended = self.attention(x, x, x, tgt_mask)
        # Residual + LayerNorm; dropout is applied after the norm here.
        query = self.dropout(self.norm(x + self_attended))
        return self.transformer_block(query, key, value, src_mask)



class TransformerBlock(nn.Module):
    """Post-norm attention sub-layer followed by a GELU feed-forward sub-layer."""

    def __init__(self, dim, num_heads, expand, p):
        super().__init__()
        self.attention = MultiHeadAttention(dim, num_heads, p)
        hidden_dim = expand * dim
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(hidden_dim, dim),
        )
        self.attn_norm = nn.LayerNorm(dim)
        self.ff_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(p)

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value`` (N, len, dim), then apply the FFN.

        Each sub-layer uses dropout, a residual connection and LayerNorm.
        """
        attended = self.dropout(self.attention(query, key, value, mask))
        hidden = self.attn_norm(query + attended)

        fed_forward = self.dropout(self.ff(hidden))
        return self.ff_norm(hidden + fed_forward)



class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        dim: model width; split evenly into ``num_heads`` heads of ``dim // num_heads``.
        num_heads: number of attention heads (must divide ``dim``).
        p: dropout probability applied to the attention weights.
    """

    def __init__(self, dim, num_heads, p):
        assert dim % num_heads == 0, 'dim must be divisible by num_heads'
        super().__init__()

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Scaled dot-product attention divides the logits by sqrt(d_k), where
        # d_k is the per-head key width. The full model width is only correct
        # when there is a single head.
        scale = torch.rsqrt(torch.tensor(self.head_dim).float())
        # Kept as a frozen Parameter (as before) so state_dict keys are unchanged.
        self.scale = nn.Parameter(scale, requires_grad=False)

        self.Q = nn.Linear(dim, dim)
        self.K = nn.Linear(dim, dim)
        self.V = nn.Linear(dim, dim)
        self.fc_out = nn.Linear(dim, dim)

        self.dropout = nn.Dropout(p)

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` (N, q_len, dim) over ``key``/``value`` (N, k_len, dim).

        ``mask`` must broadcast to (N, num_heads, q_len, k_len); positions where it
        equals 0/False get -inf logits. Returns (N, q_len, dim).
        """
        N = query.shape[0]
        query_len, key_len, value_len = query.shape[1], key.shape[1], value.shape[1]

        query = self.Q(query)
        key = self.K(key)
        value = self.V(value)

        # q,k,v: (N, len, num_heads, head_dim)
        query = query.reshape(N, query_len, self.num_heads, self.head_dim)
        key = key.reshape(N, key_len, self.num_heads, self.head_dim)
        value = value.reshape(N, value_len, self.num_heads, self.head_dim)

        # energy: (N, num_heads, query_len, key_len); multiply by 1/sqrt(head_dim)
        energy = torch.einsum('nqhd,nkhd->nhqk', [query, key]) * self.scale

        if mask is not None:
            energy.masked_fill_(mask == 0, float('-inf'))

        # attention: (N, heads, query_len, key_len)
        attention = torch.softmax(energy, dim=-1)
        attention = self.dropout(attention)

        # out: (N, query_len, num_heads, head_dim) -> (N, query_len, dim)
        out = torch.einsum('nhql,nlhd->nqhd', [attention, value])
        out = out.reshape(N, query_len, self.dim)
        return self.fc_out(out)



class PosEmbedding(nn.Module):
    """Scale an already-vectorised sequence and add learned positional embeddings.

    Unlike Embedding, there is no token lookup: ``seq`` must already have shape
    (N, seq_len, dim). The sum goes through LayerNorm and then dropout.

    Args:
        dim: feature width of each sequence element.
        seq_len: maximum supported sequence length (number of positions).
        p: dropout probability.
    """

    def __init__(self, dim, seq_len, p):
        super().__init__()
        self.pos_embedding = nn.Embedding(seq_len, dim)
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(p)
        # Position ids 0..seq_len-1, shape (1, seq_len); sliced to the input length in forward.
        self.pos = torch.arange(0, seq_len).unsqueeze(0)
        scale = torch.sqrt(torch.tensor(dim).float())
        self.scale = nn.Parameter(scale, requires_grad=False)

    def forward(self, seq):
        """Return embeddings of shape (N, seq_len, dim) for ``seq`` (N, seq_len, dim).

        Inputs shorter than the maximum length are supported. A longer input
        still raises from ``expand``, as before.
        """
        N, seq_len, _ = seq.shape
        # Slice before expanding: expand() cannot shrink a non-singleton dimension,
        # so expanding the full (1, max_len) tensor fails for seq_len < max_len.
        pos = self.pos[:, :seq_len].expand(N, seq_len).to(seq.device)
        embeds = seq * self.scale + self.pos_embedding(pos)
        return self.dropout(self.norm(embeds))


class Embedding(nn.Module):
    """Token embedding scaled by sqrt(dim) plus learned position embedding.

    The sum goes through LayerNorm and then dropout.

    Args:
        dim: embedding width.
        vocab_size: number of token ids.
        max_len: number of learned positions.
        p: dropout probability.
    """

    def __init__(self, dim, vocab_size, max_len, p):
        super().__init__()
        self.tok_embedding = nn.Embedding(vocab_size, dim)
        self.pos_embedding = nn.Embedding(max_len, dim)
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(p)
        root_dim = torch.tensor(dim).float().sqrt()
        self.scale = nn.Parameter(root_dim, requires_grad=False)

    def forward(self, seq):
        """Embed token ids ``seq`` of shape (N, len) into (N, len, dim)."""
        batch, length = seq.shape
        positions = torch.arange(length, device=seq.device)
        positions = positions.unsqueeze(0).expand(batch, length)

        tokens = self.tok_embedding(seq) * self.scale
        combined = tokens + self.pos_embedding(positions)
        return self.dropout(self.norm(combined))

In [None]:
def get_loss(outs, tgt):
    """Mean cross-entropy of logits ``outs`` (N, L, V) against ``tgt`` shifted left by one.

    ``tgt`` has shape (N, L + 1). Its first column (the start token) is never
    predicted, so the targets are ``tgt[:, 1:]``.
    """
    vocab = outs.shape[-1]
    logits = outs.reshape(-1, vocab)
    targets = tgt[:, 1:].reshape(-1)
    return F.cross_entropy(logits, targets)
    
def get_accuracy(outs, tgt):
    """Fraction of rows whose first predicted token (argmax of ``outs[:, 0]``) equals ``tgt[:, 1]``.

    Only the first decoded position is scored. Later positions are ignored.
    """
    predicted = outs[:, 0].argmax(dim=-1)
    expected = tgt[:, 1]
    return predicted.eq(expected).float().mean()

In [None]:
# --- hyper-parameters -------------------------------------------------------
batch_size = 512
epochs = 50
dim = 28                # also the feature width of each image row fed to the encoder
src_vocab_size = None   # source is continuous (image rows), no token vocabulary
tgt_vocab_size = 12     # ten digits + sos + eos
src_max_len = 28        # one encoder position per image row
tgt_max_len = 3         # [sos, digit, eos]
num_heads = 2
num_layers = 2
expand = 2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- data, model, optimiser -------------------------------------------------
train_dl, valid_dl = get_data(batch_size)
model = Seq2Seq(
    dim=dim,
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    src_max_len=src_max_len,
    tgt_max_len=tgt_max_len,
    device=device,
    num_heads=num_heads,
    num_layers=num_layers,
    expand=expand,
).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

print(f'Model has {len(model)} parameters.')

Model has 34244 parameters.


In [None]:
sos = torch.tensor(10) # start token
eos = torch.tensor(11) # end token


def _with_special_tokens(labels):
    """Turn digit labels (N,) into target sequences [sos, digit, eos] of shape (N, 3)."""
    n = labels.shape[0]
    return torch.stack([sos.expand(n), labels, eos.expand(n)], dim=1)


for epoch in range(epochs):
    model.train()
    for images, labels in train_dl:
        images = images.to(device)
        labels = _with_special_tokens(labels).to(device)
        # Teacher forcing: feed all but the last token, predict all but the first.
        outs = model(images, labels[:, :-1])
        loss = get_loss(outs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Validation metrics are weighted by batch size. A plain mean of per-batch
    # means over-weights the final, smaller batch.
    model.eval()
    total_loss, total_acc, seen = 0.0, 0.0, 0
    with torch.no_grad():
        for images, labels in valid_dl:
            images = images.to(device)
            labels = _with_special_tokens(labels).to(device)
            outs = model(images, labels[:, :-1])
            n = labels.shape[0]
            total_loss += get_loss(outs, labels).item() * n
            total_acc += get_accuracy(outs, labels).item() * n
            seen += n

    loss = total_loss / seen
    acc = total_acc / seen

    print(f'Epoch: {epoch+1:>2}\tLoss: {loss:.3f}\tAccuracy: {acc:.3f}')

Epoch:  1	Loss: 1.120	Accuracy: 0.211
Epoch:  2	Loss: 0.920	Accuracy: 0.339
Epoch:  3	Loss: 0.617	Accuracy: 0.566
Epoch:  4	Loss: 0.376	Accuracy: 0.758
Epoch:  5	Loss: 0.260	Accuracy: 0.839
Epoch:  6	Loss: 0.220	Accuracy: 0.862
Epoch:  7	Loss: 0.195	Accuracy: 0.880
Epoch:  8	Loss: 0.172	Accuracy: 0.893
Epoch:  9	Loss: 0.159	Accuracy: 0.906
Epoch: 10	Loss: 0.136	Accuracy: 0.917
Epoch: 11	Loss: 0.129	Accuracy: 0.922
Epoch: 12	Loss: 0.116	Accuracy: 0.930
Epoch: 13	Loss: 0.106	Accuracy: 0.936
Epoch: 14	Loss: 0.103	Accuracy: 0.938
Epoch: 15	Loss: 0.095	Accuracy: 0.943
Epoch: 16	Loss: 0.090	Accuracy: 0.946
Epoch: 17	Loss: 0.086	Accuracy: 0.947
Epoch: 18	Loss: 0.085	Accuracy: 0.948
Epoch: 19	Loss: 0.082	Accuracy: 0.951
Epoch: 20	Loss: 0.080	Accuracy: 0.951
Epoch: 21	Loss: 0.077	Accuracy: 0.953
Epoch: 22	Loss: 0.074	Accuracy: 0.955
Epoch: 23	Loss: 0.071	Accuracy: 0.957
Epoch: 24	Loss: 0.069	Accuracy: 0.958
Epoch: 25	Loss: 0.066	Accuracy: 0.960
Epoch: 26	Loss: 0.067	Accuracy: 0.959
Epoch: 27	Lo