In [None]:
#Morphology is the study of how words are built up from smaller meaning-bearing units (morphemes).
#Study and understand the concepts of morphology using an add-delete table.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected by three linear layers, split into ``heads``
    slices of width ``embed_size // heads``, attended per head, re-joined and
    passed through a final linear layer.

    Args:
        embed_size: Width of the last dimension of every input and of the output.
        heads: Number of attention heads; must divide ``embed_size`` evenly.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "Embedding size must be divisible by heads"

        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend ``query`` over ``keys``/``values``.

        Args:
            values: Tensor of shape ``(N, value_len, embed_size)``.
            keys: Tensor of shape ``(N, key_len, embed_size)``.
            query: Tensor of shape ``(N, query_len, embed_size)``.
            mask: ``None``, or a tensor broadcastable to
                ``(N, heads, query_len, key_len)``; positions where the mask
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        batch = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        values = self.values(values).view(batch, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).view(batch, key_len, self.heads, self.head_dim)
        queries = self.queries(query).view(batch, query_len, self.heads, self.head_dim)

        # Per-head similarity scores, shape (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys) / math.sqrt(self.head_dim)
        if mask is not None:
            # Use the most negative finite value rather than -inf: a row whose
            # keys are all masked would otherwise softmax to NaN and poison
            # every downstream activation.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)

        # Weighted sum of values per head, then concatenate the heads.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values)
        out = out.reshape(batch, query_len, self.embed_size)
        return self.fc_out(out)

class FeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        embed_size: Width of the input and of the output.
        hidden_dim: Width of the intermediate layer.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, embed_size, hidden_dim, dropout):
        super().__init__()
        self.fc1 = nn.Linear(embed_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        hidden = self.fc1(x)
        activated = torch.relu(hidden)
        regularised = self.dropout(activated)
        return self.fc2(regularised)

class TransformerBlock(nn.Module):
    """One post-norm encoder block: attention and feed-forward sub-layers.

    Each sub-layer output goes through dropout, is added to the sub-layer
    input (residual connection) and is then layer-normalised.

    Args:
        embed_size: Width of the last dimension of the inputs and the output.
        heads: Number of attention heads.
        hidden_dim: Width of the feed-forward intermediate layer.
        dropout: Dropout probability used inside the feed-forward network and
            on both sub-layer outputs.
    """

    def __init__(self, embed_size, heads, hidden_dim, dropout):
        super().__init__()
        self.attention = MultiHeadSelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, hidden_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run the block.

        Args:
            value, key, query: Inputs forwarded to the attention layer;
                ``query`` is also the residual input.
            mask: Attention mask forwarded unchanged to the attention layer
                (may be ``None``).

        Returns:
            Tensor with the same shape as ``query``.
        """
        # self.dropout was previously constructed but never applied; use it on
        # each sub-layer output before the residual add. In eval mode dropout
        # is the identity, so inference results are unchanged.
        attn = self.dropout(self.attention(value, key, query, mask))
        x = self.norm1(attn + query)
        ff = self.dropout(self.ff(x))
        return self.norm2(ff + x)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings.

    Even feature indices hold ``sin(pos * w_i)`` and odd indices hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / embed_size)``.

    Args:
        embed_size: Width of the embeddings (odd sizes are supported).
        max_len: Longest sequence length the table is built for.
    """

    def __init__(self, embed_size, max_len=100):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size)
        )
        angles = position * div_term  # (max_len, ceil(embed_size / 2))

        pe = torch.zeros(max_len, embed_size)
        pe[:, 0::2] = torch.sin(angles)
        # For odd embed_size there is one fewer odd column than even columns,
        # so trim the angles to fit instead of raising a shape error.
        pe[:, 1::2] = torch.cos(angles[:, : embed_size // 2])

        # Register as a buffer so the table follows module.to(device);
        # non-persistent keeps state_dict identical to the plain-attribute
        # version.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(N, seq_len, embed_size)`` with
                ``seq_len <= max_len``.
        """
        # .to() is a no-op when the buffer already lives on x's device.
        return x + self.pe[:, : x.size(1)].to(x.device)

class Transformer(nn.Module):
    """Stack of transformer blocks over token embeddings with a vocabulary head.

    Args:
        embed_size: Width of the token embeddings.
        heads: Number of attention heads in every block.
        hidden_dim: Width of each block's feed-forward intermediate layer.
        num_layers: Number of stacked blocks.
        vocab_size: Number of rows in the embedding table and width of the
            output layer.
        max_length: Longest sequence the positional encoding is built for.
        dropout: Dropout probability.
    """

    def __init__(self, embed_size, heads, hidden_dim, num_layers, vocab_size, max_length, dropout):
        super().__init__()
        self.embed_size = embed_size
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(embed_size, max_length)
        self.layers = nn.ModuleList(
            [TransformerBlock(embed_size, heads, hidden_dim, dropout) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Map token ids to per-position vocabulary logits.

        Args:
            x: Integer tensor of token ids, shape ``(N, seq_len)``.
            mask: Optional attention mask passed unchanged to every block.

        Returns:
            Tensor of shape ``(N, seq_len, vocab_size)``.
        """
        # The positional encoder adds its table to whatever it receives, so it
        # must be given the embeddings, not the raw integer token ids.
        embedded = self.word_embedding(x)
        out = self.dropout(self.positional_encoding(embedded))
        for layer in self.layers:
            out = layer(out, out, out, mask)
        return self.fc_out(out)

def add_delete_table(word, add_affix, delete_affix):
    """Print one add/delete table entry for ``word``.

    The "add" row is ``word`` with ``add_affix`` appended. The "delete" row is
    ``word`` with the trailing ``delete_affix`` removed; when ``word`` does not
    end with that affix, or the affix is empty, the word is shown unchanged.

    Args:
        word: Base word.
        add_affix: Suffix to append.
        delete_affix: Suffix to strip from the end of ``word``.
    """
    added_word = word + add_affix

    # Guard against an empty affix: every string "ends with" "", and
    # word[:-0] is the empty string, which would erase the whole word.
    if delete_affix and word.endswith(delete_affix):
        deleted_word = word[:-len(delete_affix)]
    else:
        deleted_word = word

    print(f"Base Word: {word}")
    print(f"After Adding '{add_affix}': {added_word}")
    print(f"After Deleting '{delete_affix}': {deleted_word}")
    print("-")

# Example usage: print a table entry for each (word, affix to add, affix to delete).
for base_word, suffix_to_add, suffix_to_drop in (
    ("happy", "ness", "y"),
    ("run", "er", "n"),
    ("act", "ion", "t"),
):
    add_delete_table(base_word, suffix_to_add, suffix_to_drop)


Base Word: happy
After Adding 'ness': happyness
After Deleting 'y': happ
-
Base Word: run
After Adding 'er': runer
After Deleting 'n': ru
-
Base Word: act
After Adding 'ion': action
After Deleting 't': ac
-
