In [None]:
import sys
sys.path.append("..")
import torch
from torch import nn
from torch import optim
import torchvision.transforms as transforms
import time
import os
# convolution patch embedding
# from Tensorized_components.patch_embedding  import Patch_Embedding        
from Tensorized_components.tcl_patch_embedding  import  PatchEmbedding  as  Patch_Embedding      
from Tensorized_components.w_msa_w_o_b_sign  import WindowMSA     
from Tensorized_components.sh_wmsa_w_o_b_sign import ShiftedWindowMSA     
from Tensorized_components.patch_merging  import TensorizedPatchMerging  
from Tensorized_Layers.TCL_CHANGED import TCL_CHANGED   
from Tensorized_Layers.TRL import TRL   
from Utils.Accuracy_measures import topk_accuracy
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders
from Utils.Num_parameter import count_parameters

In [2]:
class SwinBlock1(nn.Module):
    """
    Stage-1 Swin block: four pre-norm residual sub-layers applied in sequence.

        1. LayerNorm -> window attention (``w_msa``) -> dropout, plus skip
        2. LayerNorm -> ``trl1`` -> GELU -> ``trl2``, plus skip
        3. LayerNorm -> shifted-window attention (``sw_msa``) -> dropout, plus skip
        4. LayerNorm -> ``trl3`` -> GELU -> ``trl4``, plus skip

    The attention and TRL modules are constructed by the caller and handed in;
    this class only owns the four LayerNorms, the dropout and the GELU.

    Args:
        w_msa: module used for the first (window attention) sub-layer.
        sw_msa: module used for the third (shifted-window attention) sub-layer.
        trl1, trl2: first feed-forward pair, applied as ``trl2(gelu(trl1(.)))``.
        trl3, trl4: second feed-forward pair, applied as ``trl4(gelu(trl3(.)))``.
        embed_shape: ``normalized_shape`` handed to every ``nn.LayerNorm``.
        dropout: drop probability applied to both attention outputs only.

    Every sub-layer must return a tensor that can be added to its own input.
    """

    def __init__(self, w_msa, sw_msa, trl1,trl2,trl3,trl4, embed_shape, dropout=0):
        super().__init__()

        # A dedicated LayerNorm in front of each of the four sub-layers.
        self.norm1 = nn.LayerNorm(embed_shape)
        self.norm2 = nn.LayerNorm(embed_shape)
        self.norm3 = nn.LayerNorm(embed_shape)
        self.norm4 = nn.LayerNorm(embed_shape)

        # A single dropout module, reused after both attention sub-layers.
        self.dropout = nn.Dropout(dropout)

        # Externally built sub-modules (registration order is kept stable).
        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.trl1 = trl1
        self.gelu = nn.GELU()
        self.trl2 = trl2
        self.trl3 = trl3
        self.trl4 = trl4

    def _attention_sublayer(self, skip, norm, attention):
        """Return ``dropout(attention(norm(skip))) + skip``."""
        return self.dropout(attention(norm(skip))) + skip

    def _feed_forward_sublayer(self, skip, norm, expand, project):
        """Return ``project(gelu(expand(norm(skip)))) + skip``."""
        return project(self.gelu(expand(norm(skip)))) + skip

    def forward(self, x):
        """Apply the four residual sub-layers and return the result."""
        x = self._attention_sublayer(x, self.norm1, self.w_msa)
        x = self._feed_forward_sublayer(x, self.norm2, self.trl1, self.trl2)
        x = self._attention_sublayer(x, self.norm3, self.sw_msa)
        x = self._feed_forward_sublayer(x, self.norm4, self.trl3, self.trl4)
        return x


In [3]:
class SwinBlock2(nn.Module):
    """
    Stage-2 Swin block.

    Runs, in order: window attention, a TRL/GELU/TRL feed-forward, shifted-window
    attention and a second TRL/GELU/TRL feed-forward. Each of the four sub-layers
    is preceded by its own LayerNorm and wrapped in a residual connection.
    Dropout is applied to the two attention outputs only.

    Args:
        w_msa: window-attention module (built by the caller).
        sw_msa: shifted-window-attention module (built by the caller).
        trl1, trl2: modules of the first feed-forward, used as ``trl2(gelu(trl1(.)))``.
        trl3, trl4: modules of the second feed-forward, used as ``trl4(gelu(trl3(.)))``.
        embed_shape: ``normalized_shape`` for the four LayerNorms.
        dropout: drop probability for the attention outputs.
    """

    def __init__(self, w_msa, sw_msa, trl1, trl2 , trl3 , trl4,  embed_shape=(4,4,6), dropout=0):
        super().__init__()

        # Pre-normalisation, one LayerNorm per sub-layer.
        self.norm1 = nn.LayerNorm(embed_shape)
        self.norm2 = nn.LayerNorm(embed_shape)
        self.norm3 = nn.LayerNorm(embed_shape)
        self.norm4 = nn.LayerNorm(embed_shape)

        self.dropout = nn.Dropout(dropout)

        # Sub-modules supplied by the caller.
        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.trl1 = trl1
        self.trl2 = trl2
        self.trl3 = trl3
        self.trl4 = trl4
        self.gelu = nn.GELU()

    def forward(self, x):
        """Return ``x`` after the four residual sub-layers."""
        # (1) window attention
        attended = self.dropout(self.w_msa(self.norm1(x)))
        x = attended + x

        # (2) feed-forward: trl1 -> GELU -> trl2
        mixed = self.trl2(self.gelu(self.trl1(self.norm2(x))))
        x = mixed + x

        # (3) shifted-window attention
        attended = self.dropout(self.sw_msa(self.norm3(x)))
        x = attended + x

        # (4) feed-forward: trl3 -> GELU -> trl4
        mixed = self.trl4(self.gelu(self.trl3(self.norm4(x))))
        return mixed + x


In [4]:
class SwinBlock3(nn.Module):
    """
    Stage-3 Swin block.

    The block is two identical-looking halves run back to back. Each half is
    an attention sub-layer followed by a TRL/GELU/TRL feed-forward sub-layer,
    both pre-normalised with their own LayerNorm and both residual. The first
    half uses ``w_msa`` with ``trl1``/``trl2``; the second uses ``sw_msa`` with
    ``trl3``/``trl4``. Dropout is applied to the attention outputs only.

    Args:
        w_msa, sw_msa: attention modules for the first and second half.
        trl1, trl2, trl3, trl4: feed-forward modules (see above for pairing).
        embed_shape: ``normalized_shape`` for the four LayerNorms.
        dropout: drop probability for the attention outputs.
    """

    def __init__(self, w_msa, sw_msa, trl1,  trl2 , trl3 , trl4,   embed_shape=(4,4,12), dropout=0):
        super().__init__()

        # norm1/norm2 serve the first half, norm3/norm4 the second.
        self.norm1 = nn.LayerNorm(embed_shape)
        self.norm2 = nn.LayerNorm(embed_shape)
        self.norm3 = nn.LayerNorm(embed_shape)
        self.norm4 = nn.LayerNorm(embed_shape)

        self.dropout = nn.Dropout(dropout)

        # Caller-provided sub-modules.
        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.trl1 = trl1
        self.gelu = nn.GELU()
        self.trl2 = trl2
        self.trl3 = trl3
        self.trl4 = trl4

    def forward(self, x):
        """Run both halves of the block and return the result."""
        halves = (
            (self.norm1, self.w_msa, self.norm2, self.trl1, self.trl2),
            (self.norm3, self.sw_msa, self.norm4, self.trl3, self.trl4),
        )
        for attn_norm, attention, ff_norm, expand, project in halves:
            # attention sub-layer with skip connection
            x = self.dropout(attention(attn_norm(x))) + x
            # feed-forward sub-layer with skip connection
            x = project(self.gelu(expand(ff_norm(x)))) + x
        return x

In [5]:
class SwinBlock4(nn.Module):
    """
    Stage-4 Swin block.

    Applies four pre-norm residual sub-layers: window attention, feed-forward
    (``trl1`` -> GELU -> ``trl2``), shifted-window attention, feed-forward
    (``trl3`` -> GELU -> ``trl4``). Only the attention outputs go through dropout.

    Args:
        w_msa: window-attention module supplied by the caller.
        sw_msa: shifted-window-attention module supplied by the caller.
        trl1, trl2, trl3, trl4: feed-forward modules supplied by the caller.
        embed_shape: ``normalized_shape`` for the four LayerNorms.
        dropout: drop probability for the attention outputs.
    """

    def __init__(self, w_msa, sw_msa, trl1, trl2 , trl3 , trl4 ,  embed_shape=(4,4,24), dropout=0):
        super().__init__()

        # Separate normalisation in front of every sub-layer.
        self.norm1 = nn.LayerNorm(embed_shape)
        self.norm2 = nn.LayerNorm(embed_shape)
        self.norm3 = nn.LayerNorm(embed_shape)
        self.norm4 = nn.LayerNorm(embed_shape)

        self.dropout = nn.Dropout(dropout)

        # Pre-built sub-modules.
        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.trl1 = trl1
        self.trl2 = trl2
        self.trl3 = trl3
        self.trl4 = trl4
        self.gelu = nn.GELU()

    def _with_attention(self, x, norm, attention):
        """Residual attention step: ``dropout(attention(norm(x))) + x``."""
        branch = attention(norm(x))
        branch = self.dropout(branch)
        return branch + x

    def _with_feed_forward(self, x, norm, first, second):
        """Residual feed-forward step: ``second(gelu(first(norm(x)))) + x``."""
        branch = first(norm(x))
        branch = second(self.gelu(branch))
        return branch + x

    def forward(self, x):
        """Return ``x`` transformed by the four residual sub-layers."""
        x = self._with_attention(x, self.norm1, self.w_msa)
        x = self._with_feed_forward(x, self.norm2, self.trl1, self.trl2)
        x = self._with_attention(x, self.norm3, self.sw_msa)
        return self._with_feed_forward(x, self.norm4, self.trl3, self.trl4)


In [6]:
class SwinTransformer(nn.Module):
    """
    Four-stage tensorized Swin Transformer ending in a 200-way TRL classifier.

    Pipeline (see ``forward``): patch embedding -> additive learned position
    embedding -> stage 1 -> patch merging -> stage 2 -> patch merging ->
    stage 3 -> patch merging -> stage 4 -> mean over dims 1 and 2 -> classifier.

    Every tensorized layer is constructed with literal input sizes, so the model
    is wired for one configuration only:

        stage   grid      embedding shape   hidden shape   blocks
        1       56 x 56   (4, 4, 12)        (4, 4, 48)     2
        2       28 x 28   (4, 4, 24)        (4, 4, 96)     2
        3       14 x 14   (4, 4, 48)        (4, 4, 192)    18
        4        7 x  7   (4, 4, 96)        (4, 4, 384)    2

    All layers are built with a leading size of 16 (the batch size used by the
    notebook) and the patch embedding with an input size of (16, 3, 224, 224).

    Args:
        img_size: currently unused; the patch embedding is built for 224 x 224.
        patch_size: forwarded to the patch embedding.
        in_chans: currently unused; the patch embedding is built for 3 channels.
        embed_shape: stage-1 embedding shape. It is forwarded to the patch
            embedding, the stage-1 attention modules, the stage-1 LayerNorms
            and the first patch merging; every other shape is a literal that
            corresponds to ``(4, 4, 12)``.
            NOTE(review): other values look unsupported - confirm before use.
        bias: forwarded to the patch embedding, patch merging and TRL layers.
        dropout: forwarded to every Swin block.
        device: device handed to every submodule and used for ``pos_embedding``.

    NOTE(review): inside a stage every repeated block is given the *same*
    attention and TRL module objects, so those weights are shared by all
    repeats of the stage; only the LayerNorms are per-block. If independent
    weights were intended, the modules would have to be constructed inside the
    list comprehensions - which changes the parameter count and the
    checkpoint keys, so it is deliberately not done here.
    """

    # Leading (batch) size every tensorized layer is constructed with.
    # TODO (from the original code): make the batch size configurable.
    _BATCH = 16
    # Attention window size shared by all stages.
    _WINDOW = 7

    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_shape=(4,4,12),
                 bias=True,
                 dropout=0,
                 device="cuda"):
        super().__init__()

        self.device = device

        # ------------------------- patch embedding (TCL) -------------------
        # Built on the model's device like every other submodule; it used to be
        # pinned to "cpu" regardless of `device`.
        # (A convolutional Patch_Embedding is the alternative implementation;
        # its import is commented out at the top of the notebook.)
        self.patch_embedding = Patch_Embedding(
            input_size=(self._BATCH, 3, 224, 224),
            patch_size=patch_size,
            embed_dim=embed_shape,
            bias=bias,
            device=self.device,
            ignore_modes=(0, 1, 2)
        )

        # -------------------------------- stage 1 --------------------------
        self.w_msa_1 = self._window_attention(WindowMSA, embed_shape, (1, 2, 3))
        self.sw_msa_1 = self._window_attention(ShiftedWindowMSA, embed_shape, (1, 2, 3))

        self.trl_1 = self._feed_forward_trl(56, (4, 4, 12), (4, 4, 48), bias)
        self.trl_1_2 = self._feed_forward_trl(56, (4, 4, 48), (4, 4, 12), bias)
        self.trl_1_3 = self._feed_forward_trl(56, (4, 4, 12), (4, 4, 48), bias)
        self.trl_1_4 = self._feed_forward_trl(56, (4, 4, 48), (4, 4, 12), bias)

        # Two blocks; both reference the modules created just above.
        self.block1_list = nn.ModuleList([
            SwinBlock1(
                w_msa=self.w_msa_1,
                sw_msa=self.sw_msa_1,
                trl1=self.trl_1,
                trl2=self.trl_1_2,
                trl3=self.trl_1_3,
                trl4=self.trl_1_4,
                embed_shape=embed_shape,
                dropout=dropout
            )
            for _ in range(2)
        ])

        # -------------------------------- stage 2 --------------------------
        self.patch_merging_1 = TensorizedPatchMerging(
            input_size=(self._BATCH, 56, 56, 4, 4, 12),
            in_embed_shape=embed_shape,
            out_embed_shape=(4, 4, 24),
            bias=bias,
            ignore_modes=(0, 1, 2),
            device=self.device
        )

        self.w_msa_2 = self._window_attention(WindowMSA, (4, 4, 24), (1, 2, 6))
        self.sw_msa_2 = self._window_attention(ShiftedWindowMSA, (4, 4, 24), (1, 2, 6))

        self.trl_2 = self._feed_forward_trl(28, (4, 4, 24), (4, 4, 96), bias)
        self.trl_2_2 = self._feed_forward_trl(28, (4, 4, 96), (4, 4, 24), bias)
        self.trl_2_3 = self._feed_forward_trl(28, (4, 4, 24), (4, 4, 96), bias)
        self.trl_2_4 = self._feed_forward_trl(28, (4, 4, 96), (4, 4, 24), bias)

        # Two blocks.
        self.block2_list = nn.ModuleList([
            SwinBlock2(
                w_msa=self.w_msa_2,
                sw_msa=self.sw_msa_2,
                trl1=self.trl_2,
                trl2=self.trl_2_2,
                trl3=self.trl_2_3,
                trl4=self.trl_2_4,
                embed_shape=(4, 4, 24),
                dropout=dropout
            )
            for _ in range(2)
        ])

        # -------------------------------- stage 3 --------------------------
        self.patch_merging_2 = TensorizedPatchMerging(
            input_size=(self._BATCH, 28, 28, 4, 4, 24),
            in_embed_shape=(4, 4, 24),
            out_embed_shape=(4, 4, 48),
            bias=bias,
            ignore_modes=(0, 1, 2),
            device=self.device
        )

        self.w_msa_3 = self._window_attention(WindowMSA, (4, 4, 48), (2, 1, 12))
        self.sw_msa_3 = self._window_attention(ShiftedWindowMSA, (4, 4, 48), (2, 1, 12))

        self.trl_3 = self._feed_forward_trl(14, (4, 4, 48), (4, 4, 192), bias)
        self.trl_3_2 = self._feed_forward_trl(14, (4, 4, 192), (4, 4, 48), bias)
        self.trl_3_3 = self._feed_forward_trl(14, (4, 4, 48), (4, 4, 192), bias)
        self.trl_3_4 = self._feed_forward_trl(14, (4, 4, 192), (4, 4, 48), bias)

        # Eighteen blocks (an older comment said six; the code builds 18).
        self.block3_list = nn.ModuleList([
            SwinBlock3(
                w_msa=self.w_msa_3,
                sw_msa=self.sw_msa_3,
                trl1=self.trl_3,
                trl2=self.trl_3_2,
                trl3=self.trl_3_3,
                trl4=self.trl_3_4,
                embed_shape=(4, 4, 48),
                dropout=dropout
            )
            for _ in range(18)
        ])

        # -------------------------------- stage 4 --------------------------
        self.patch_merging_3 = TensorizedPatchMerging(
            input_size=(self._BATCH, 14, 14, 4, 4, 48),
            in_embed_shape=(4, 4, 48),
            out_embed_shape=(4, 4, 96),
            bias=bias,
            ignore_modes=(0, 1, 2),
            device=self.device
        )

        self.w_msa_4 = self._window_attention(WindowMSA, (4, 4, 96), (2, 1, 24))
        self.sw_msa_4 = self._window_attention(ShiftedWindowMSA, (4, 4, 96), (2, 1, 24))

        self.trl_4 = self._feed_forward_trl(7, (4, 4, 96), (4, 4, 384), bias)
        self.trl_4_2 = self._feed_forward_trl(7, (4, 4, 384), (4, 4, 96), bias)
        self.trl_4_3 = self._feed_forward_trl(7, (4, 4, 96), (4, 4, 384), bias)
        self.trl_4_4 = self._feed_forward_trl(7, (4, 4, 384), (4, 4, 96), bias)

        # Two blocks.
        self.block4_list = nn.ModuleList([
            SwinBlock4(
                w_msa=self.w_msa_4,
                sw_msa=self.sw_msa_4,
                trl1=self.trl_4,
                trl2=self.trl_4_2,
                trl3=self.trl_4_3,
                trl4=self.trl_4_4,
                embed_shape=(4, 4, 96),
                dropout=dropout
            )
            for _ in range(2)
        ])

        # -------------------------------- classifier -----------------------
        # Maps the pooled (4, 4, 96) embedding to 200 outputs.
        self.classifier = TRL(input_size=(self._BATCH, 4, 4, 96),
                              output=(200,),
                              rank=(4, 4, 96, 200),
                              ignore_modes=(0,),
                              bias=bias,
                              device=self.device)

        # -------------------------------- position embedding ---------------
        # Learned additive embedding; the leading 1 broadcasts over the batch.
        self.pos_embedding = nn.Parameter(
            torch.randn(1, 56, 56, 4, 4, 12, device=self.device),
            requires_grad=True)

    def _window_attention(self, attention_cls, shape, head_factors):
        """Build ``attention_cls`` (WindowMSA / ShiftedWindowMSA) for one stage.

        ``shape`` is used both as the embedding dims and as the rank.
        """
        return attention_cls(
            window_size=self._WINDOW,
            embed_dims=shape,
            rank_window=shape,
            head_factors=head_factors,
            device=self.device
        )

    def _feed_forward_trl(self, grid, in_shape, out_shape, bias):
        """Build one feed-forward TRL mapping ``in_shape`` to ``out_shape``.

        The layer is constructed for an input of size
        ``(batch, grid, grid, *in_shape)`` with the first three modes ignored
        and rank ``in_shape + out_shape``.
        """
        return TRL(
            input_size=(self._BATCH, grid, grid) + tuple(in_shape),
            output=tuple(out_shape),
            rank=tuple(in_shape) + tuple(out_shape),
            ignore_modes=(0, 1, 2),
            bias=bias,
            device=self.device
        )

    def forward(self, x):
        """Return the classifier output for a batch of images ``x``.

        Assumes ``x`` matches the (16, 3, 224, 224) input size the patch
        embedding was built with - TODO confirm whether other batch sizes work.
        """
        x = self.patch_embedding(x)

        # Out-of-place add: `x += ...` would overwrite the patch-embedding
        # output in place, which autograd may still need for the backward pass.
        x = x + self.pos_embedding

        for blk in self.block1_list:
            x = blk(x)

        x = self.patch_merging_1(x)

        for blk in self.block2_list:
            x = blk(x)

        x = self.patch_merging_2(x)

        for blk in self.block3_list:
            x = blk(x)

        x = self.patch_merging_3(x)

        for blk in self.block4_list:
            x = blk(x)

        # Average over dims 1 and 2 (the 7 x 7 grid per the stage-4 input sizes).
        x = x.mean(dim=(1, 2))

        return self.classifier(x)

In [None]:
# ---- Device ---------------------------------------------------------------
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device is set to : {device}')

# ---- Experiment configuration ---------------------------------------------
TEST_ID = 'Test_ID00034'   # names the results directory of this run
batch_size = 16
n_epoch = 400
image_size = 224

model = SwinTransformer(
    img_size=224,
    patch_size=4,
    in_chans=3,
    embed_shape=(4, 4, 12),
    bias=True,
    device=device,
).to(device)

# ---- Transforms -----------------------------------------------------------
# Every pipeline resizes to image_size x image_size and normalises with the
# same per-channel statistics; only the training pipeline augments.
_target_hw = (image_size, image_size)
_channel_mean = (0.485, 0.456, 0.406)
_channel_std = (0.229, 0.224, 0.225)

tiny_transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(_target_hw),
    transforms.RandomCrop(image_size, padding=5),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

# Validation and test use the same deterministic pipeline (two separate objects).
tiny_transform_val, tiny_transform_test = (
    transforms.Compose([
        transforms.Resize(_target_hw),
        transforms.ToTensor(),
        transforms.Normalize(_channel_mean, _channel_std),
    ])
    for _ in range(2)
)

# ---- Data loaders ---------------------------------------------------------
train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(
    data_dir='../datasets',
    transform_train=tiny_transform_train,
    transform_val=tiny_transform_val,
    transform_test=tiny_transform_test,
    batch_size=batch_size,
    image_size=image_size,
)

Device is set to : cpu


: 

In [None]:
# Report the model size once; `num_parameters` is also written to
# model_info.txt further down in this cell.
num_parameters = count_parameters(model)
print(f'This Model has {num_parameters} parameters')
    
# Loss and optimizer used by train_epoch / test_epoch below (read as globals).
# Both are created with their library-default settings: no arguments are passed
# to CrossEntropyLoss, and Adam only receives the model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())


# Define train and test functions (use examples)
def train_epoch(loader, epoch):
    """Train the global ``model`` for one pass over ``loader``.

    Relies on the notebook-level globals ``model``, ``optimizer``,
    ``criterion``, ``device`` and ``topk_accuracy``.

    Args:
        loader: iterable yielding ``(inputs, targets)`` batches.
        epoch: epoch number; only used in the report text.

    Returns:
        The one-line report (also printed) with top-1..top-5 accuracy in
        percent, the mean loss per sample and the elapsed wall-clock time.

    Notes:
        * ``criterion`` is expected to return the mean loss of the batch (the
          notebook uses a default ``nn.CrossEntropyLoss()``), so each batch
          loss is weighted by its batch size before averaging. Previously the
          sum of batch means was divided by the dataset size, which shrank the
          reported loss by roughly the batch size.
        * Accuracies are multiplied by 100 so they match the ``%`` suffix in
          the report; previously fractions in [0, 1] were printed with ``%``.
        * Samples are counted as they are seen, so the figures stay correct
          for loaders that drop the last batch, and an empty loader yields
          zeros instead of a division error.
    """
    model.train()

    start_time = time.time()
    topk = (1, 2, 3, 4, 5)
    correct = {k: 0.0 for k in topk}   # running correct-count per k
    loss_sum = 0.0                     # sum of per-sample losses
    n_seen = 0                         # samples processed so far

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        batch_n = targets.size(0)
        loss_sum += loss.item() * batch_n   # undo the per-batch mean
        n_seen += batch_n

        accuracies = topk_accuracy(outputs, targets, topk=topk)
        for k in accuracies:
            correct[k] += accuracies[k]['correct']

    elapsed_time = time.time() - start_time
    denom = max(n_seen, 1)   # guards the empty-loader case
    top1_acc, top2_acc, top3_acc, top4_acc, top5_acc = [100.0 * correct[k] / denom for k in topk]
    avg_loss = loss_sum / denom

    report_train = f'Train epoch {epoch}: top1={top1_acc}%, top2={top2_acc}%, top3={top3_acc}%, top4={top4_acc}%, top5={top5_acc}%, loss={avg_loss}, time={elapsed_time}s'
    print(report_train)

    return report_train

def test_epoch(loader, epoch):
    model.eval()

    start_time = time.time()
    running_loss = 0.0
    correct = {1:0.0, 2:0.0, 3:0.0, 4:0.0, 5:0.0} # set the initial correct count for top1-to-top5 accuracy

    for _, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        running_loss += loss.item()
        accuracies = topk_accuracy(outputs, targets, topk=(1, 2, 3, 4, 5))
        for k in accuracies:
            correct[k] += accuracies[k]['correct']

    elapsed_time = time.time() - start_time
    top1_acc, top2_acc, top3_acc, top4_acc, top5_acc = [(correct[k]/len(loader.dataset)) for k in correct]
    avg_loss = running_loss / len(loader.dataset)

    report_test = f'Test epoch {epoch}: top1={top1_acc}%, top2={top2_acc}%, top3={top3_acc}%, top4={top4_acc}%, top5={top5_acc}%, loss={avg_loss}, time={elapsed_time}s'
    print(report_test)

    return report_test

# Results layout: ../results/<TEST_ID>/accuracy_stats and .../model_stats
result_dir = os.path.join('../results', TEST_ID)
result_subdir = os.path.join(result_dir, 'accuracy_stats')
model_subdir = os.path.join(result_dir, 'model_stats')

for _folder in (result_subdir, model_subdir):
    os.makedirs(_folder, exist_ok=True)

# Append the parameter count to the model info file.
with open(os.path.join(model_subdir, 'model_info.txt'), 'a') as f:
    f.write(f'total number of parameters:\n{num_parameters}')

# Train from scratch - training only; evaluation on test_loader is disabled.
print(f'Training for {len(range(n_epoch))} epochs\n')
train_report_path = os.path.join(result_subdir, 'report_train.txt')
for epoch in range(1, n_epoch + 1):
    report_train = train_epoch(train_loader, epoch)
    # report_test = test_epoch(test_loader, epoch)

    report = report_train + '\n'  # + report_test + '\n\n'

    # Checkpoint the weights on every fifth epoch.
    if not epoch % 5:
        model_path = os.path.join(model_subdir, f'Model_epoch_{epoch}.pth')
        torch.save(model.state_dict(), model_path)

    # Append this epoch's report line to the running training log.
    with open(train_report_path, 'a') as f:
        f.write(report)
            

This Model has 52472588 parameters
Training for 400 epochs

