
# Deep Learning and Neural Networks Second Assignment 2024
### Ferenc Huszár and Nic Lane
#### due date: Friday, 21 March 2025, 12:00 PM

## Information

Welcome to the second assignment for the Deep Neural Networks module. In this assignment you will explore some of the model architectures we talked about in the second half of lectures (ConvNets, Transformers) and you will also implement a fun model called MLP Mixer.

There are 70 marks given in total for this second assessment ($70\%$ of the total of 100 marks for the course), broken into three sections:
* (D) 10 marks for the ConvNet exercise
* (E) 20 marks for the Transformer/MLP Mixer exercise
* (F) 40 marks for a mini-project of your choice


#### Mini-projects

Mini-project tasks are more exploratory and open-ended, giving you an opportunity to decide which aspect you'd like to focus on. The idea is to introduce you to the form of assessment typical in our Part III/MPhil modules. Mini-projects come with instructions indicating the depth of work we expect for certain marks, but you should feel free to deviate from them if you have a better idea to explore within the context. There are two options, one with a more theory/maths focus, one with more of an engineering focus.

As a guide, when marking we will take into account three factors:

* **extent of work:** did you do the expected amount of work? (You won't get extra marks for doing a lot more than others; this is not a race.) We will try to give an indication of this in the module description.
* **correctness/technical understanding:** are your solution and description of findings technically correct? Do they demonstrate learning and understanding of the topics we cover, and an ability to do independent reading if needed?
* **presentation:** How is the mini-project written up? Please focus on the writeup being short, to the point, well structured. Are figures well formatted, so it's clear what's shown on them (e.g. are there axis labels)?

You can choose whichever project you want to attempt. You can attempt more than one, but we will only mark one. **Please clearly state which of the mini-projects you would like us to mark**; if this is unclear, we will mark whichever appears first in your submitted notebook.

# F.1: Compositional Generalisation in Group Structured Data
_[40 marks] 4-page writeup plus appendices_

The starting point for this exercise is Section E, where you trained models to predict missing entries of the Cayley table for $\mathbb{Z}_{97}$ given only a random subset of entries. If your models worked well there, and you like groups, this mini-project might appeal to you.

There is a lot of interesting compositional structure in groups. The 97-element cyclic group is rather boring: because $97$ is prime, Lagrange's theorem implies that its only subgroups are the trivial group $\{e\}$ and the group itself (in particular, it is a simple group). Cyclic groups of non-prime order $n$, on the other hand, have more interesting subgroup structure: for every divisor $d$ of $n$, there is exactly one subgroup of order (size) $d$. Symmetric groups (groups of permutations) have even richer subgroup structure. For example, $S_4$ contains 24 elements and has subgroups of every order dividing 24, often several of the same order. See details [on Wikipedia](https://en.wikipedia.org/wiki/Subgroup#Example:_Subgroups_of_S4). Elementary Abelian groups $\mathbb{Z}_p^k$ are another example of groups with rich subgroup structure.
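As a quick check of the claim about cyclic groups, here is a minimal sketch that enumerates the subgroups of $\mathbb{Z}_n$ by collecting the subgroup generated by each element. Since every subgroup of a cyclic group is itself cyclic, this finds all of them. The helper names `cyclic_subgroups` and `divisors` are illustrative and not part of the assignment code.

```python
def cyclic_subgroups(n):
    """Return every subgroup of Z_n (integers mod n under addition) as a frozenset.

    Every subgroup of a cyclic group is cyclic, so collecting the subgroup <g>
    generated by each element g is enough to find all of them.
    """
    return {frozenset((g * k) % n for k in range(n)) for g in range(n)}


def divisors(n):
    return [d for d in range(1, n + 1) if n % d == 0]


print(sorted(len(H) for H in cyclic_subgroups(12)))  # [1, 2, 3, 4, 6, 12]
print(divisors(12))                                   # [1, 2, 3, 4, 6, 12]
print(sorted(len(H) for H in cyclic_subgroups(97)))  # [1, 97]
```

The list of subgroup orders contains no repeats, so there is exactly one subgroup per divisor.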

In this mini-project, your task is to collect evidence relating to the following hypothesis:

_Hypothesis:_ Transformers (and/or other models) are able to learn the full Cayley table of rich groups like $S_n$ or $\mathbb{Z}_p^k$ after seeing data only from certain sets of subgroups (i.e. specific slices of the Cayley table).

We know that small transformers are able to learn the Cayley table of $S_5$ when a small (less than 50%) fraction of randomly selected entries is provided as training data. To test the hypothesis above, you would design a different training-test split strategy which selects the union of certain subgroups as training data, with the remaining entries as test data. Is sampling subgroups better or worse than random sampling?
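One way to set up this comparison, sketched here for the small group $S_4$ with plain tuples (not the `Permutation`/`Group` classes used in the notebook), is to train on the union of the slices $H \times H$ for some chosen subgroups $H$, test on every other entry, and compare against a uniformly random training set of the same size. The choice of subgroups (the four point stabilisers, each isomorphic to $S_3$) is only an example, and all helper names are hypothetical.

```python
import itertools
import random


def compose(f, g):
    """(f∘g)(i) = f(g(i)): apply g first, then f."""
    return tuple(f[g[i]] for i in range(len(g)))


def subgroup_union_split(elements, subgroups):
    """Train on the Cayley-table slices H x H of the chosen subgroups; test on the rest."""
    train = {(a, b) for H in subgroups for a in H for b in H}
    test = {(a, b) for a in elements for b in elements} - train
    return train, test


def random_split_same_size(elements, n_train, seed=0):
    """Baseline: a uniformly random set of n_train table entries."""
    pairs = [(a, b) for a in elements for b in elements]
    train = set(random.Random(seed).sample(pairs, n_train))
    return train, set(pairs) - train


def labelled(pairs):
    """Attach the target a∘b to each input pair (a, b)."""
    return [((a, b), compose(a, b)) for a, b in sorted(pairs)]


S4 = list(itertools.permutations(range(4)))
stabilisers = [[p for p in S4 if p[i] == i] for i in range(4)]  # Stab(i), a copy of S_3

sub_train, sub_test = subgroup_union_split(S4, stabilisers)
rnd_train, rnd_test = random_split_same_size(S4, len(sub_train))
train_examples = labelled(sub_train)
print(len(sub_train), len(sub_test), len(rnd_train))  # 123 453 123
```

Matching the size of the two training sets keeps the comparison about *which* entries are seen, not *how many*.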

High marks for a complete investigation, including commentary on what your findings might reveal about compositional and statistical generalisation in Transformers and/or other models. You are expected to put more effort into tuning your hyperparameters and improving your models than Part E alone requires.

## Groups

In [1]:
from typing import Iterable

In [2]:
class Permutation(tuple):
    def __getitem__(self, indices: int | slice | Iterable):
        if type(indices) in (int, slice):
            return tuple.__getitem__(self, indices)

        assert isinstance(indices, Permutation)
        return Permutation([self[i] for i in indices])

    def __matmul__(self, other):  # f@g, composition
        return self[other]

    def __pow__(self, exponent):  # repeated composition
        if exponent == 0:
            return Permutation(range(len(self)))  # identity
        return self @ self**(exponent-1)

    def __invert__(self):  # equivalent to argsort
        return Permutation([self.index(i) for i in range(len(self))])

    def __call__(self, other: int | Iterable):
        return self[other]

    def insert(self, index, value: int):
        assert value == len(self)
        return Permutation(self[:index] + (value,) + self[index:])

    def __eq__(self, other):
        return type(other)==type(self) and tuple.__eq__(self, other)
    def __ne__(self, other):
        return not self==other
    def __hash__(self):
        return hash((type(self), tuple.__hash__(self)))

    def __repr__(self):
        return tuple.__repr__(self).replace("(", "⟦").replace(",)", "⟧").replace(")", "⟧")
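The composition convention matters when building the Cayley table: `f @ g` above is `f[g]`, so `g` is applied first. A standalone sketch of the same conventions on plain tuples (the function names `compose` and `inverse` are illustrative only) also shows that composition in $S_n$ is not commutative:

```python
def compose(f, g):
    """Return f∘g as a tuple: (f∘g)(i) = f[g[i]], i.e. apply g first, then f.

    This mirrors Permutation.__matmul__ above, where f @ g == f[g].
    """
    return tuple(f[g[i]] for i in range(len(g)))


def inverse(p):
    """Inverse permutation via argsort, so that inverse(p)[p[i]] == i."""
    return tuple(sorted(range(len(p)), key=lambda i: p[i]))


f = (1, 2, 0)  # 0 -> 1, 1 -> 2, 2 -> 0
g = (1, 0, 2)  # swaps 0 and 1

print(compose(f, g))           # (2, 1, 0)
print(compose(g, f))           # (0, 2, 1)
print(compose(inverse(f), f))  # (0, 1, 2), the identity
```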

In [3]:
class Tuple(tuple):
    def __new__(cls, iterable=None):
        if iterable is None:
            iterable = ()
        return super().__new__(cls, tuple(item for item in iterable if item!=Tuple()))

    def __invert__(self):
        return Tuple(~item for item in self)

    def __matmul__(self, other):
        assert all(len(p1)==len(p2) for p1, p2 in zip(self, other))
        return Tuple(p1@p2 for p1, p2 in zip(self, other))

    def __eq__(self, other):
        return type(other)==type(self) and tuple.__eq__(self, other)
    def __ne__(self, other):
        return not self==other
    def __hash__(self):
        return hash((type(self), tuple.__hash__(self)))

    def __repr__(self):
        return tuple.__repr__(self).replace("(", "⦅").replace(",)", "⦆").replace(")", "⦆")

In [4]:
class Group(set):
    def __init__(self, iterable: Iterable[Permutation | Tuple], check_valid=True):
        assert len(iterable) >= 1
        if len(iterable) == 1 and () in iterable:
            iterable = [Tuple()]
        super().__init__(set(iterable))

        if not check_valid or (len(iterable)==1 and Tuple() in iterable):
            return

        assert all(~obj in self for obj in self)  # inverses
        assert all(objSecond@objFirst in self for objFirst, objSecond in self**2)  # closure;  given inverses, this implies an identity
        assert all((objThird@objSecond)@objFirst == objThird@(objSecond@objFirst)  for objFirst, (objSecond, objThird) in self**3) # associativity


    def __mul__(self, other):  # cartesian product
        assert isinstance(other, Group)   # then the product is a valid group...
        return Group({Tuple((obj1, obj2)) for obj1 in self for obj2 in other}, check_valid=False) # ...so we don't need to check again

    def __iter__(self):
        return set.__iter__(self)

    def __pow__(self, exponent):
        if exponent==1:
            return self
        if exponent == 0:
            return Group({()})
        return self * self**(exponent-1)

    def __repr__(self):
        if len(self) == 0:
            return "⦃ ⦄"
        return f"⦃{repr(set(self))[1:-1]}⦄"

    @property
    def table(self):
        return {(obj2, obj1) : obj2@obj1  for obj1, obj2 in self**2}
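`Group.__mul__` builds direct products of permutation groups out of `Tuple`s. For the elementary Abelian groups $\mathbb{Z}_p^k$ mentioned in the brief, a lighter-weight sketch represents elements as plain $k$-tuples of residues with componentwise addition mod $p$; it does not reuse the classes above, and its helper names are hypothetical. As in `Group.table`, the key `(x, y)` maps to the product of `x` and `y`.

```python
import itertools


def elementary_abelian(p, k):
    """Elements of Z_p^k as k-tuples of residues mod p."""
    return list(itertools.product(range(p), repeat=k))


def add(x, y, p):
    """Componentwise addition mod p: the direct-product group operation."""
    return tuple((a + b) % p for a, b in zip(x, y))


def cayley_table(elements, op):
    """Dict {(x, y): op(x, y)}, analogous to the Group.table property above."""
    return {(x, y): op(x, y) for x in elements for y in elements}


p, k = 2, 3
G = elementary_abelian(p, k)
table = cayley_table(G, lambda x, y: add(x, y, p))
identity = (0,) * k

print(len(G), len(table))                         # 8 64
print(all(table[(x, x)] == identity for x in G))  # True: every element is its own inverse
```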

In [5]:
def S(n):  # the Symmetric Group (i.e. the set of permutations) of {0, ..., n-1}
    if n == 0:
        return Group({Permutation([])})
    return Group({perm.insert(index, n-1)  for perm in S(n-1) for index in range(n)}, check_valid=False)  # we know S_n is a valid group!
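`S(n)` builds $S_n$ by inserting the new symbol $n-1$ into every position of every permutation of $\{0, \dots, n-2\}$, giving $n \cdot (n-1)! = n!$ distinct permutations. A standalone sketch of the same recursion on plain tuples (the name `symmetric_group` is hypothetical) can be checked against `itertools.permutations`:

```python
import itertools


def symmetric_group(n):
    """All permutations of {0, ..., n-1}, built like S(n) above: insert n-1
    into every position of every permutation of {0, ..., n-2}."""
    if n == 0:
        return {()}
    return {perm[:i] + (n - 1,) + perm[i:]
            for perm in symmetric_group(n - 1)
            for i in range(n)}


for n in range(6):
    assert symmetric_group(n) == set(itertools.permutations(range(n)))
print(len(symmetric_group(6)))  # 720
```

Since `S(n)` rebuilds every smaller group on each call, and the notebook calls `S(6)` several times, caching it (for example with `functools.lru_cache`) is one option. Note, though, that the returned `Group` is a mutable set, so a cached result must not be modified.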

## Transformer

In [6]:
import torch
from torch.utils.data import TensorDataset, Dataset, DataLoader, ChainDataset, random_split
from torch import nn
from torch.nn import functional as F
import random
from math import factorial
from google.colab import drive

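def random_batch_split(num_batches, train):
    # Randomly split `train` into `num_batches` batches of (near-)equal size.
    # random_split accepts fractions such as [0.1]*10 but seemed to fail for [0.05]*20, so for more
    # than 10 batches we pass integer lengths instead, giving the remainder to the first batches.
    # NB: the generator is re-seeded on every call, so the same batches are produced in every epoch.
    if num_batches <= 10:
        batch_list = num_batches * [1/num_batches]
    else:
        batch_list = num_batches * [len(train)//num_batches]
        num_remaining = len(train) - sum(batch_list)
        for i in range(num_remaining):
            batch_list[i] += 1
    return random_split(train, batch_list, generator=torch.Generator().manual_seed(10))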
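One plausible cause of the failure with `[0.05]*20` is that twenty copies of `0.05` need not sum to exactly 1 in floating point. This is an assumption: `random_split` expects fractional lengths to sum to 1. Integer lengths avoid the issue entirely. A minimal sketch of the integer branch on its own (the name `equal_integer_lengths` is hypothetical):

```python
def equal_integer_lengths(n_items, num_batches):
    """Split n_items into num_batches integer batch sizes that differ by at most 1.

    Integer lengths always sum exactly to n_items, so no floating-point
    fractions are involved (same idea as the else-branch of random_batch_split).
    """
    base, remainder = divmod(n_items, num_batches)
    return [base + 1 if i < remainder else base for i in range(num_batches)]


lengths = equal_integer_lengths(103, 20)
print(sum(lengths), min(lengths), max(lengths))  # 103 5 6
```

The result can be passed as the `lengths` argument of `torch.utils.data.random_split`, which accepts integer lengths. Creating the `torch.Generator` once, outside the epoch loop, would give a fresh shuffle in each epoch.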

In [7]:
class EncodedSymmetricGroupTransformer(nn.Module): # with each permutation represented by a single integer
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        transformer_layer = nn.TransformerEncoderLayer(embed_dim, num_heads, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers, enable_nested_tensor=False)
        self.linear = nn.Linear(embed_dim, vocab_size)
        self.pos_embedding = nn.Parameter(torch.zeros(2, embed_dim))

    def forward(self, x):
        embedding = self.embedding(x)
        token_emb = embedding + self.pos_embedding[:embedding.size(1), :]
        encoded = self.transformer_encoder(token_emb.transpose(0, 1))
        return self.linear(encoded.mean(dim=0))

In [79]:
def format_losses(epoch, training_loss, validation_loss):
    return f"   ({epoch}, {training_loss:.4f}, {validation_loss:.4f})"


def train_model(model, train, validation, num_epochs, learning_rate, num_batches=5, weight_decay=0.05, betas=(0.95, 0.999)):
    validation_inputs, validation_labels = validation[::]
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=betas)
    loss_fn = nn.CrossEntropyLoss()

    def print_losses(model, epoch):
        # Losses are estimated on (at most) the first 50,000 examples of each split, to fit in memory.
        model.eval()
        with torch.no_grad():
            training_inputs, training_labels = train[::]
            avg_training_loss = loss_fn(model(training_inputs[:50000]), training_labels[:50000])
            validation_loss = loss_fn(model(validation_inputs[:50000]), validation_labels[:50000])
        print(format_losses(epoch, avg_training_loss, validation_loss), end="")
        model.train()

    print("(Epoch, Training Loss, Validation Loss):")
    print_losses(model, 0)

    for epoch in range(1, num_epochs+1):
        for batch in random_batch_split(num_batches, train):
            input_batch, label_batch = batch[::]
            optimizer.zero_grad()  # reset for every batch, so gradients don't accumulate across batches
            training_loss = loss_fn(model(input_batch), label_batch)
            training_loss.backward()
            optimizer.step()

        if 5*epoch//num_epochs > 5*(epoch-1)//num_epochs:  # log 5 times over the run
            print_losses(model, epoch)

    model.eval()
    with torch.no_grad():
        predicted = model(validation_inputs[:50000]).argmax(dim=-1)  # same as the argmax of the softmax
        accuracy = (predicted==validation_labels[:50000]).float().mean().item()
        print('\n    Accuracy of "best-guess"es on the validation set:', f"{accuracy*100:.2f}%", "\n")
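The logging condition `5*epoch//num_epochs > 5*(epoch-1)//num_epochs` fires whenever the integer part of `5*epoch/num_epochs` increases, i.e. at the end of each fifth of training. A small standalone check (the function name `logged_epochs` is hypothetical) matches the epochs printed in the training logs:

```python
def logged_epochs(num_epochs, num_logs=5):
    """Epochs at which train_model prints losses: those where
    num_logs * epoch // num_epochs increases."""
    return [epoch for epoch in range(1, num_epochs + 1)
            if num_logs * epoch // num_epochs > num_logs * (epoch - 1) // num_epochs]


print(logged_epochs(1000))  # [200, 400, 600, 800, 1000]
print(logged_epochs(300))   # [60, 120, 180, 240, 300]
```

When `num_epochs` is not a multiple of 5, the logs are unevenly spaced, but there are still exactly 5 of them (for `num_epochs >= 5`).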

In [29]:
class SymmetricGroupDataset(TensorDataset):
    def __init__(self, base_set_size, alphabet=None):  # `base_set_size` i.e. the `n` in S_n
        if alphabet is None:
            alphabet = list(range(factorial(base_set_size)))
        alphabet = list(alphabet)  # copy, so shuffling doesn't modify the caller's list
        random.seed(10)
        random.shuffle(alphabet)
        self.alphabet = alphabet
        self.group = S(base_set_size)

        # token -> permutation, and permutation -> token. The group is iterated in sorted order because
        # set order depends on Permutation.__hash__, which includes hash(type(self)) and so can differ
        # between Python sessions; sorting keeps the mapping reproducible.
        self.mapping_dict = {}
        self.mapping_inverse_dict = {}
        for permutation, token in zip(sorted(self.group), self.alphabet):
            self.mapping_dict[token] = permutation
            self.mapping_inverse_dict[permutation] = token

        inputs = []
        outputs = []
        for (perm1, perm2), perm3 in self.group.table.items():
            inputs.append([self.unmap(perm1), self.unmap(perm2)])
            outputs.append(self.unmap(perm3))

        super().__init__(torch.tensor(inputs), torch.tensor(outputs))  # sets self.tensors

    def map(self, token):
        return self.mapping_dict[token]

    def unmap(self, permutation):
        return self.mapping_inverse_dict[permutation]

    def cuda(self):
        self.tensors = (self.tensors[0].cuda(), self.tensors[1].cuda())
        return self
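Because the tokens are shuffled, the integer ids carry no information about the permutations, so the only structure a model can exploit is the table itself. One property of every group's Cayley table that survives relabelling is the Latin-square property: each row and each column contains every element exactly once. Below is a standalone sketch for $S_3$. It uses its own seeded `random.Random` rather than the global seed, so its token assignment is not the notebook's, and `tokenised_cayley_table` is a hypothetical name.

```python
import itertools
import random


def tokenised_cayley_table(n, seed=10):
    """Cayley table of S_n with each permutation replaced by a shuffled integer token."""
    perms = sorted(itertools.permutations(range(n)))  # sorted: reproducible order
    tokens = list(range(len(perms)))
    random.Random(seed).shuffle(tokens)
    to_token = dict(zip(perms, tokens))

    def compose(f, g):  # (f∘g)(i) = f(g(i))
        return tuple(f[g[i]] for i in range(n))

    return {(to_token[f], to_token[g]): to_token[compose(f, g)]
            for f in perms for g in perms}


table = tokenised_cayley_table(3)
size = 6  # |S_3|
every_token = list(range(size))
rows_ok = all(sorted(table[(a, b)] for b in range(size)) == every_token for a in range(size))
cols_ok = all(sorted(table[(a, b)] for a in range(size)) == every_token for b in range(size))
print(len(table), rows_ok, cols_ok)  # 36 True True
```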

In [10]:
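# subgroups_s5[i] is the stabiliser of the point i, {p in S_6 : p(i) = i}: a subgroup isomorphic to S_5
# (hence the name). Permutations fixing no point at all (derangements) lie in none of these subgroups.
S6 = S(6)
subgroups_s5 = [set() for _ in range(6)]

for perm in S6:
    for i, element in enumerate(perm):
        if i==element:
            subgroups_s5[i].add(perm)

derangements = set(S6) - set().union(*subgroups_s5)
subgroups_s5 = [Group(s, check_valid=False) for s in subgroups_s5]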
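Each `subgroups_s5[i]` should contain $5! = 120$ permutations, and the derangements of six points number 265. A standalone cross-check using `itertools.permutations` instead of the `S`/`Group` classes (the variable names are illustrative) also verifies closure for one stabiliser:

```python
import itertools

n = 6
perms_6 = list(itertools.permutations(range(n)))

# Stab(i): permutations fixing the point i; each is a copy of S_5.
stabilisers = [{p for p in perms_6 if p[i] == i} for i in range(n)]
# Derangements: permutations with no fixed point at all.
derangements = set(perms_6) - set().union(*stabilisers)


def compose(f, g):
    return tuple(f[g[j]] for j in range(len(g)))


# Closure check for one stabiliser (120 * 120 products).
closed = all(compose(f, g) in stabilisers[0] for f in stabilisers[0] for g in stabilisers[0])
print([len(s) for s in stabilisers], len(derangements), closed)
# [120, 120, 120, 120, 120, 120] 265 True
```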

In [31]:
base_set_size = 6
whole_dataset = SymmetricGroupDataset(base_set_size)
u = whole_dataset.unmap

training_set, validation_set, holdout_set = random_split([(i, j) for i in range(720) for j in range(720)],
                                                [0.5, 0.1, 0.4], generator=torch.Generator().manual_seed(10))
training_set = set(training_set)
validation_set = set(validation_set)
holdout_set = set(holdout_set)
# Note: for now, training and validation only use the entries of these splits that fall inside the
# subgroup slices, so the effective holdout set is much larger than the nominal 40%.

subgroupsTrain = []
subgroupsValidation = []

for i in range(6):
    train_inputs = []
    train_labels = []
    validation_inputs = []
    validation_labels = []
    table = list(subgroups_s5[i].table.items())
    random.shuffle(table)
    for (p1, p2), p3 in table:
        if (u(p1), u(p2)) in training_set:
            train_inputs.append((u(p1), u(p2)))
            train_labels.append(u(p3))
        elif (u(p1), u(p2)) in validation_set:
            validation_inputs.append((u(p1), u(p2)))
            validation_labels.append(u(p3))

    subgroupsTrain.append(TensorDataset(torch.tensor(train_inputs).cuda(), torch.tensor(train_labels).cuda()))
    subgroupsValidation.append(TensorDataset(torch.tensor(validation_inputs).cuda(), torch.tensor(validation_labels).cuda()))

In [164]:
model = EncodedSymmetricGroupTransformer(vocab_size=factorial(base_set_size), embed_dim=64, num_heads=4, num_layers=3, dropout=0.2).cuda()

In [165]:
train_model(model, subgroupsTrain[0], subgroupsValidation[0], num_epochs=1000, learning_rate=0.003, num_batches=1, weight_decay=1.5)

(Epoch, Training Loss, Validation Loss):
   (0, 6.7095, 6.7087)   (200, 3.9393, 5.4489)   (400, 0.3771, 7.0608)   (600, 0.0561, 4.3936)   (800, 0.0313, 0.1549)   (1000, 0.0434, 0.0524)
    Accuracy of "best-guess"es on the validation set:" 100.00% 



In [166]:
#torch.save(model.state_dict(), "/content/drive/MyDrive/DNNs2/earlyStage_state_dict.pth")

In [184]:
model.load_state_dict(torch.load("/content/drive/MyDrive/DNNs2/earlyStage_state_dict.pth"))

<All keys matched successfully>

In [185]:
for i in range(1, 6):
    train_model(model, subgroupsTrain[i], subgroupsValidation[i], num_epochs=500, learning_rate=0.0006, num_batches=1, weight_decay=0.5)

(Epoch, Training Loss, Validation Loss):
   (0, 8.8561, 8.8484)   (100, 4.6541, 4.7725)   (200, 3.6789, 3.8957)   (300, 3.0644, 3.1324)   (400, 1.9544, 2.0055)   (500, 0.6074, 0.6373)
    Accuracy of "best-guess"es on the validation set:" 99.32% 

(Epoch, Training Loss, Validation Loss):
   (0, 10.1599, 10.1220)   (100, 4.4720, 4.5686)   (200, 3.3447, 3.6584)   (300, 2.5849, 2.9697)   (400, 1.4929, 1.6129)   (500, 0.4766, 0.5305)
    Accuracy of "best-guess"es on the validation set:" 99.51% 

(Epoch, Training Loss, Validation Loss):
   (0, 10.4990, 10.5969)   (100, 4.3033, 4.4606)   (200, 2.8723, 3.2337)   (300, 1.5258, 1.5947)   (400, 0.6681, 0.7048)   (500, 0.2184, 0.2413)
    Accuracy of "best-guess"es on the validation set:" 99.93% 

(Epoch, Training Loss, Validation Loss):
   (0, 10.4515, 10.4012)   (100, 3.1577, 3.3768)   (200, 1.4835, 1.5819)   (300, 0.5806, 0.6129)   (400, 0.1665, 0.1850)   (500, 0.0520, 0.0631)
    Accuracy of "best-guess"es on the validation set:" 100.00% 


In [186]:
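# Concatenate TensorDatasets into a single TensorDataset. (torch.utils.data.ConcatDataset
# doesn't support the slice indexing `dataset[::]` that train_model relies on.)
def combineDatasets(datasets):
    return TensorDataset(torch.cat([d.tensors[0] for d in datasets]), torch.cat([d.tensors[1] for d in datasets]))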

In [187]:
#torch.save(model.state_dict(), "/content/drive/MyDrive/DNNs2/midStage_state_dict.pth")

In [341]:
model.load_state_dict(torch.load("/content/drive/MyDrive/DNNs2/midStage_state_dict.pth"))

<All keys matched successfully>

In [342]:
train_model(model, combineDatasets(subgroupsTrain), combineDatasets(subgroupsValidation), num_epochs=300,
            learning_rate=0.002, num_batches=1, weight_decay=0.1)

(Epoch, Training Loss, Validation Loss):
   (0, 6.7663, 6.7231)   (60, 1.7538, 1.8973)   (120, 0.5255, 0.7075)   (180, 0.1022, 0.1904)   (240, 0.0190, 0.0401)   (300, 0.0062, 0.0137)
    Accuracy of "best-guess"es on the validation set:" 99.99% 



In [346]:
#torch.save(model.state_dict(), "/content/drive/MyDrive/DNNs2/midLateStage_state_dict.pth")

In [345]:
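# Every entry of the S_6 Cayley table, written in tokens: ((token1, token2), token_of_product).
all_entries = {((u(p1), u(p2)), u(p3)) for (p1, p2), p3 in S(6).table.items()}

# Entries used for training or validation in any of the six subgroup slices.
seen_entries = set()
for dataset in subgroupsTrain + subgroupsValidation:
    inputs, labels = dataset.tensors
    seen_entries |= {(tuple(pair), label) for pair, label in zip(inputs.tolist(), labels.tolist())}

unseen_inputs = []
unseen_labels = []
for (u_p1, u_p2), u_p3 in all_entries - seen_entries:
    unseen_inputs.append((u_p1, u_p2))
    unseen_labels.append(u_p3)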

In [360]:
unseen = TensorDataset(torch.tensor(unseen_inputs).cuda(), torch.tensor(unseen_labels).cuda())
train, validation, holdout = random_split(unseen, [0.035, 0.005, 0.96])

In [396]:
model.load_state_dict(torch.load("/content/drive/MyDrive/DNNs2/midLateStage_state_dict.pth"))

<All keys matched successfully>

In [None]:
train_model(model, train, validation, num_epochs=1500, learning_rate=0.002, num_batches=1, weight_decay=0.0001)

(Epoch, Training Loss, Validation Loss):
   (0, 9.3693, 9.4054)   (300, 0.0065, 0.0445)   (600, 0.0004, 0.0047)

In [384]:
#import pickle
#with open("/content/drive/MyDrive/DNNs2/unmap.pkl", "wb") as f:
#    pickle.dump(whole_dataset.mapping_inverse_dict, f)

In [117]:
#torch.save(model.state_dict(), "/content/drive/MyDrive/DNNs2/final_model_state_dict.pth")

In [393]:
holdout_inputs, holdout_labels = holdout[::]

model.eval()
with torch.no_grad():
    # Evaluate in chunks of 50,000 examples, since the full holdout set doesn't fit in memory in one pass.
    logits = torch.cat([model(input_chunk) for input_chunk in torch.split(holdout_inputs, 50000)])
    predicted = logits.argmax(dim=-1)  # same as the argmax of the softmax probabilities
    accuracy = (predicted==holdout_labels).float().mean().item()
    print('\tAccuracy of "best-guess"es on the holdout set:', f"{accuracy*100:.4f}%", "\n\n")

	Accuracy of "best-guess"es on the holdout set:" 99.9800% 




In [133]:
unscheduledModel = EncodedSymmetricGroupTransformer(vocab_size=factorial(base_set_size), embed_dim=64, num_heads=4,
                                                    num_layers=3, dropout=0.2).cuda()

train_model(unscheduledModel, train, validation, num_epochs=1000, learning_rate=0.0035,
            num_batches=1, weight_decay=0.01)

(Epoch, Training Loss, Validation Loss):
   (0, 6.7037, 6.6955)   (200, 0.0247, 14.5634)   (400, 0.0124, 17.4340)   (600, 0.0008, 18.7432)   (800, 0.0000, 19.6541)   (1000, 0.0000, 20.4355)
    Accuracy of "best-guess"es on the validation set:" 0.08% 



In [135]:
unscheduledSmallModel = EncodedSymmetricGroupTransformer(vocab_size=factorial(base_set_size), embed_dim=16, num_heads=4,
                                                       num_layers=3, dropout=0.5).cuda()

train_model(unscheduledSmallModel, train, validation, num_epochs=1000, learning_rate=0.001,
            num_batches=5, weight_decay=1)

(Epoch, Training Loss, Validation Loss):
   (0, 6.7126, 6.7279)   (200, 5.9495, 7.0703)   (400, 5.1723, 7.8606)   (600, 5.1170, 7.8337)   (800, 5.1302, 7.8067)   (1000, 5.1138, 7.7529)
    Accuracy of "best-guess"es on the validation set:" 0.20% 



NOOOO! The comparison (which will also be quicker to run) should be with a randomly selected training sample of *the same size*.


Going for "a way bigger dataset, and still worse!!" is cool, but it is very expensive to compute, and the former claim is still significant.
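For the matched baseline, the random sample should contain as many distinct table entries as the subgroup-based training set. In the notebook, each slice is further split 50/10/40, and entries shared by two stabilisers appear twice in `combineDatasets(subgroupsTrain)`, so the right size is the number of *distinct* training entries. As a reference point, assuming the full slices of all six point stabilisers were used, inclusion-exclusion gives the size of their union. Below is a sketch with a brute-force check on $S_4$; all names are hypothetical.

```python
import random
from itertools import permutations
from math import comb, factorial


def stabiliser_slice_union_size(n):
    """Entries (a, b) of the S_n Cayley table lying in some slice Stab(i) x Stab(i),
    i.e. pairs where a and b share a fixed point. Inclusion-exclusion over the set T
    of points both must fix: (n - |T|)! permutations fix T, so each term is squared."""
    return sum((-1) ** (k + 1) * comb(n, k) * factorial(n - k) ** 2
               for k in range(1, n + 1))


def brute_force_slice_union(n):
    perms = list(permutations(range(n)))
    return {(a, b) for a in perms for b in perms
            if any(a[i] == i and b[i] == i for i in range(n))}


def matched_random_sample(n, k, seed=0):
    """Baseline: k table entries of S_n drawn uniformly at random without replacement."""
    perms = list(permutations(range(n)))
    pairs = [(a, b) for a in perms for b in perms]
    return set(random.Random(seed).sample(pairs, k))


small = brute_force_slice_union(4)
baseline = matched_random_sample(4, len(small))
print(len(small), stabiliser_slice_union_size(4), len(baseline))  # 123 123 123

union, total = stabiliser_slice_union_size(6), factorial(6) ** 2
print(union, total, f"{union / total:.2%}")                       # 78425 518400 15.13%
```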