In [1]:
import sys
sys.path.append("..")
import torch
import numpy as np
from torch import nn
from torch import optim
import torchvision.transforms as transforms
import time
import os
from Tensorized_components.patch_embedding  import Patch_Embedding     
from Tensorized_components.w_msa_w_o_b_sign  import WindowMSA     
from Tensorized_components.sh_wmsa_w_o_b_sign import ShiftedWindowMSA     
from Tensorized_components.patch_merging  import TensorizedPatchMerging  
from Tensorized_Layers.TCL_CHANGED import TCL_CHANGED   
from Tensorized_Layers.TRL import TRL   
from Utils.Accuracy_measures import topk_accuracy
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders
from Utils.Num_parameter import count_parameters


In [2]:
# --------------------------------------------------------------------------------
# Utilities for DropPath (Stochastic Depth)
# --------------------------------------------------------------------------------
def drop_path(x, drop_prob: float = 0.0, training: bool = True):
    """Stochastic depth: randomly zero whole samples of ``x`` and rescale the rest.

    One Bernoulli(keep_prob) draw is made per entry along dim 0 and broadcast
    over every remaining dimension, so the function works for any tensor rank
    >= 1 (the 5D/6D activations used in this notebook included).

    Args:
        x: Input tensor; dim 0 is treated as the per-sample (batch) dimension.
        drop_prob: Probability of dropping a sample's path.
        training: When False the input is returned untouched.

    Returns:
        ``x`` itself when ``drop_prob == 0`` or ``training`` is False; otherwise
        a tensor of the same shape where each sample is either all zeros or
        ``x / (1 - drop_prob)``.
    """
    if drop_prob == 0.0 or not training:
        return x

    keep_prob = 1.0 - drop_prob
    if keep_prob <= 0.0:
        # Every path is dropped. Returning zeros directly avoids the
        # x / 0 * 0 = NaN that the rescaling below would otherwise produce.
        return torch.zeros_like(x)

    # Mask of shape (B, 1, 1, ..., 1) so it broadcasts over all non-batch dims.
    mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
    random_tensor = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor.floor_()
    # Dividing by keep_prob keeps the expected activation unchanged.
    return x / keep_prob * random_tensor

class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (stochastic depth).

    Args:
        drop_prob: Probability of dropping a sample's path while training.
    """
    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # self.training is toggled by model.train() / model.eval(), so the
        # layer is a no-op at evaluation time.
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        # Shown by print(model), which makes the per-block schedule visible.
        return f"drop_prob={self.drop_prob:.3f}"



In [3]:
class SwinBlock1(nn.Module):
    """Stage-1 Swin block: four pre-norm residual sub-layers applied in sequence.

    Order of sub-layers: window MSA -> TCL -> shifted-window MSA -> TCL.
    Dropout is applied to the two attention outputs only; DropPath is applied
    to every branch before it is added back to the residual stream. The same
    ``tcl`` module serves both TCL sub-layers.

    Args:
        w_msa: Window attention module.
        sw_msa: Shifted-window attention module.
        tcl: Module used for both TCL sub-layers.
        embed_shape: Trailing dimensions normalised by each LayerNorm.
        dropout: Dropout probability for the attention outputs.
        drop_path_rate: Stochastic-depth probability for this block.
    """
    def __init__(self, w_msa, sw_msa, tcl, embed_shape, dropout=0.0, drop_path_rate=0.0):
        super().__init__()
        # One LayerNorm in front of each of the four sub-layers.
        self.norm1 = nn.LayerNorm(embed_shape)
        self.norm2 = nn.LayerNorm(embed_shape)
        self.norm3 = nn.LayerNorm(embed_shape)
        self.norm4 = nn.LayerNorm(embed_shape)

        # Regularisers.
        self.dropout = nn.Dropout(dropout)
        self.drop_path = DropPath(drop_path_rate)

        # Sub-layers supplied by the caller.
        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.tcl = tcl

    def forward(self, x):
        # Each line is: residual + drop_path(sub_layer(norm(residual))).
        x = x + self.drop_path(self.dropout(self.w_msa(self.norm1(x))))   # window MSA
        x = x + self.drop_path(self.tcl(self.norm2(x)))                   # TCL
        x = x + self.drop_path(self.dropout(self.sw_msa(self.norm3(x))))  # shifted-window MSA
        x = x + self.drop_path(self.tcl(self.norm4(x)))                   # TCL
        return x


class SwinBlock2(nn.Module):
    """Stage-2 Swin block: window MSA, TCL, shifted-window MSA, TCL.

    Every sub-layer is wrapped as ``x + drop_path(layer(norm(x)))``; the two
    attention sub-layers additionally pass through dropout. The single ``tcl``
    module is reused for both TCL sub-layers.

    Args:
        w_msa: Window attention module.
        sw_msa: Shifted-window attention module.
        tcl: Module used for both TCL sub-layers.
        embed_shape: Trailing dimensions normalised by each LayerNorm.
        dropout: Dropout probability for the attention outputs.
        drop_path_rate: Stochastic-depth probability for this block.
    """
    def __init__(self, w_msa, sw_msa, tcl, embed_shape=(4, 4, 6), dropout=0.0, drop_path_rate=0.0):
        super().__init__()

        # Four independent LayerNorms, registered as norm1..norm4 in that order.
        self.norm1, self.norm2, self.norm3, self.norm4 = (
            nn.LayerNorm(embed_shape) for _ in range(4)
        )

        self.dropout = nn.Dropout(dropout)
        self.drop_path = DropPath(drop_path_rate)

        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.tcl = tcl

    def forward(self, x):
        # (normalisation, sub-layer, whether dropout follows the sub-layer)
        pipeline = (
            (self.norm1, self.w_msa, True),    # window MSA
            (self.norm2, self.tcl, False),     # TCL
            (self.norm3, self.sw_msa, True),   # shifted-window MSA
            (self.norm4, self.tcl, False),     # TCL
        )
        for norm, layer, with_dropout in pipeline:
            branch = layer(norm(x))
            if with_dropout:
                branch = self.dropout(branch)
            x = x + self.drop_path(branch)
        return x


class SwinBlock3(nn.Module):
    """Stage-3 Swin block made of four pre-norm residual steps.

    The steps run as window MSA -> TCL -> shifted-window MSA -> TCL. Attention
    outputs go through dropout; all four branches go through DropPath before
    being added to the residual. Both TCL steps call the same ``tcl`` module.

    Args:
        w_msa: Window attention module.
        sw_msa: Shifted-window attention module.
        tcl: Module used for both TCL steps.
        embed_shape: Trailing dimensions normalised by each LayerNorm.
        dropout: Dropout probability for the attention outputs.
        drop_path_rate: Stochastic-depth probability for this block.
    """
    def __init__(self, w_msa, sw_msa, tcl, embed_shape=(4,4,12),
                 dropout=0.0, drop_path_rate=0.0):
        super().__init__()

        self.norm1 = nn.LayerNorm(embed_shape)
        self.norm2 = nn.LayerNorm(embed_shape)
        self.norm3 = nn.LayerNorm(embed_shape)
        self.norm4 = nn.LayerNorm(embed_shape)

        self.dropout = nn.Dropout(dropout)
        self.drop_path = DropPath(drop_path_rate)

        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.tcl = tcl

    def _attention_step(self, x, norm, attention):
        """Residual attention step: x + drop_path(dropout(attention(norm(x))))."""
        return x + self.drop_path(self.dropout(attention(norm(x))))

    def _tcl_step(self, x, norm):
        """Residual TCL step: x + drop_path(tcl(norm(x)))."""
        return x + self.drop_path(self.tcl(norm(x)))

    def forward(self, x):
        x = self._attention_step(x, self.norm1, self.w_msa)
        x = self._tcl_step(x, self.norm2)
        x = self._attention_step(x, self.norm3, self.sw_msa)
        x = self._tcl_step(x, self.norm4)
        return x


class SwinBlock4(nn.Module):
    """Stage-4 Swin block: window MSA, TCL, shifted-window MSA, TCL.

    Each sub-layer is pre-normalised and added back to its input through
    DropPath. Dropout follows the two attention sub-layers only, and one
    ``tcl`` module is used for both TCL sub-layers.

    Args:
        w_msa: Window attention module.
        sw_msa: Shifted-window attention module.
        tcl: Module used for both TCL sub-layers.
        embed_shape: Trailing dimensions normalised by each LayerNorm.
        dropout: Dropout probability for the attention outputs.
        drop_path_rate: Stochastic-depth probability for this block.
    """
    def __init__(self, w_msa, sw_msa, tcl, embed_shape=(4,4,24),
                 dropout=0.0, drop_path_rate=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_shape)
        self.norm2 = nn.LayerNorm(embed_shape)
        self.norm3 = nn.LayerNorm(embed_shape)
        self.norm4 = nn.LayerNorm(embed_shape)

        self.dropout = nn.Dropout(dropout)
        self.drop_path = DropPath(drop_path_rate)

        self.w_msa = w_msa
        self.sw_msa = sw_msa
        self.tcl = tcl

    def forward(self, x):
        # Window attention branch.
        shortcut = x
        attn_out = self.w_msa(self.norm1(shortcut))
        x = shortcut + self.drop_path(self.dropout(attn_out))

        # First TCL branch.
        shortcut = x
        tcl_out = self.tcl(self.norm2(shortcut))
        x = shortcut + self.drop_path(tcl_out)

        # Shifted-window attention branch.
        shortcut = x
        attn_out = self.sw_msa(self.norm3(shortcut))
        x = shortcut + self.drop_path(self.dropout(attn_out))

        # Second TCL branch (same tcl module as the first).
        shortcut = x
        tcl_out = self.tcl(self.norm4(shortcut))
        return shortcut + self.drop_path(tcl_out)

In [4]:
def create_drop_path_rates(num_blocks, max_drop=0.5):
    """
    Build a stochastic-depth schedule: ``num_blocks`` rates spaced evenly
    from 0 up to ``max_drop`` (both ends included), returned as a Python list.
    Example: num_blocks=24, max_drop=0.5 -> [0.0, 0.0217..., ..., 0.5].
    """
    schedule = torch.linspace(start=0, end=max_drop, steps=num_blocks)
    return [rate.item() for rate in schedule]


In [5]:
class SwinTransformer(nn.Module):
    """Four-stage Swin-style classifier built from tensorized components.

    Pipeline (see ``forward``): patch embedding -> learned position embedding
    -> stage 1 blocks -> patch merging -> stage 2 blocks -> patch merging ->
    stage 3 blocks -> patch merging -> stage 4 blocks -> mean over dims 1 and 2
    -> TRL classifier with 200 outputs.

    Within each stage a single WindowMSA, ShiftedWindowMSA and TCL_CHANGED
    instance is created and handed to every block of that stage, so all blocks
    of a stage share those sub-modules; only the LayerNorms and the drop-path
    rate are per-block.

    Args:
        img_size: Forwarded to ``Patch_Embedding``.
        patch_size: Forwarded to ``Patch_Embedding``.
        in_chans: Forwarded to ``Patch_Embedding``.
        embed_shape: Tensorized embedding shape used for patch embedding and
            stage 1. Later stages use the literal shapes (4,4,24), (4,4,48)
            and (4,4,96) regardless of this argument.
        bias: Forwarded to the patch embedding, TCL, patch-merging and TRL layers.
        dropout: Dropout probability passed to every Swin block.
        max_drop_path: Largest stochastic-depth rate; per-block rates grow
            linearly from 0 to this value across all blocks.
        device: Device passed to the tensorized sub-modules and used for the
            position embedding.

    NOTE(review): many sizes below are literals (batch 16, grids 96/48/24,
    position embedding (1, 96, 96, 4, 4, 12)). They agree with the defaults
    img_size=384, patch_size=4, embed_shape=(4,4,12) and batch_size=16, but are
    not derived from the constructor arguments — other values will likely not
    work without editing these literals.
    """
    def __init__(self,
                 img_size=384,
                 patch_size=4,
                 in_chans=3,
                 embed_shape=(4,4,12),
                 bias=True,
                 dropout=0,
                 max_drop_path=0.2,
                 device="cuda"):
        super(SwinTransformer, self).__init__()
        self.device = device

        # Number of Swin blocks in each of the four stages.
        self.depths = [2, 2, 18, 2]
        total_blocks = sum(self.depths)  # 2 + 2 + 18 + 2 = 24

        # Create a list of linearly spaced drop path rates from 0 to max_drop_path
        drop_path_rates = create_drop_path_rates(total_blocks, max_drop_path)

        # Running offset into drop_path_rates so each block gets its own rate.
        dpr_index = 0

        # ------------------ Patch Embedding ------------------
        self.patch_embedding = Patch_Embedding(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_shape=embed_shape,
            bias=bias
        )

        # ------------------ Stage 1 Setup -------------------
        # One attention / shifted-attention / TCL instance for the whole stage;
        # every block in block1_list receives these same objects.
        self.w_msa_1 = WindowMSA(
            window_size=12,
            embed_dims=embed_shape,
            rank_window=embed_shape,
            head_factors=(1,2,3),
            device=self.device
        )
        self.sw_msa_1 = ShiftedWindowMSA(
            window_size=12,
            embed_dims=embed_shape,
            rank_window=embed_shape,
            head_factors=(1,2,3),
            device=self.device
        )
        # NOTE(review): the leading 16 presumably mirrors batch_size = 16 and
        # 96 presumably is img_size / patch_size; modes 0-2 are listed in
        # ignore_modes, so they may not matter — confirm in TCL_CHANGED.
        self.tcl_1 = TCL_CHANGED(
            input_size=(16, 96, 96, 4,4,12),
            rank=embed_shape,
            ignore_modes=(0, 1, 2),
            bias=bias,
            device=self.device
        )

        self.block1_list = nn.ModuleList([
            SwinBlock1(
                w_msa=self.w_msa_1,
                sw_msa=self.sw_msa_1,
                tcl=self.tcl_1,
                embed_shape=embed_shape,
                dropout=dropout,
                drop_path_rate=drop_path_rates[dpr_index + i]
            )
            for i in range(self.depths[0])
        ])
        dpr_index += self.depths[0]

        # Stage 1 -> 2 transition: embedding shape (4,4,12) -> (4,4,24).
        self.patch_merging_1 = TensorizedPatchMerging(
            input_size=(16, 96, 96, 4,4,12),
            in_embed_shape=embed_shape,
            out_embed_shape=(4,4,24),
            bias=bias,
            ignore_modes=(0, 1, 2),
            device=self.device
        )

        # ------------------ Stage 2 Setup -------------------
        # Shared across all blocks of stage 2 (same pattern as stage 1).
        self.w_msa_2 = WindowMSA(
            window_size=12,
            embed_dims=(4,4,24),
            rank_window=(4,4,24),
            head_factors=(1,2,6),
            device=self.device
        )
        self.sw_msa_2 = ShiftedWindowMSA(
            window_size=12,
            embed_dims=(4,4,24),
            rank_window=(4,4,24),
            head_factors=(1,2,6),
            device=self.device
        )
        self.tcl_2 = TCL_CHANGED(
            input_size=(16, 48, 48, 4,4,24),
            rank=(4,4,24),
            ignore_modes=(0, 1, 2),
            bias=bias,
            device=self.device
        )

        self.block2_list = nn.ModuleList([
            SwinBlock2(
                w_msa=self.w_msa_2,
                sw_msa=self.sw_msa_2,
                tcl=self.tcl_2,
                embed_shape=(4,4,24),
                dropout=dropout,
                drop_path_rate=drop_path_rates[dpr_index + i]
            )
            for i in range(self.depths[1])
        ])
        dpr_index += self.depths[1]

        # Stage 2 -> 3 transition: embedding shape (4,4,24) -> (4,4,48).
        self.patch_merging_2 = TensorizedPatchMerging(
            input_size=(16, 48, 48, 4,4,24),
            in_embed_shape=(4,4,24),
            out_embed_shape=(4,4,48),
            bias=bias,
            ignore_modes=(0, 1, 2),
            device=self.device
        )

        # ------------------ Stage 3 Setup -------------------
        # Shared across all 18 blocks of stage 3.
        self.w_msa_3 = WindowMSA(
            window_size=12,
            embed_dims=(4,4,48),
            rank_window=(4,4,48),
            head_factors=(2,1,12),
            device=self.device
        )
        self.sw_msa_3 = ShiftedWindowMSA(
            window_size=12,
            embed_dims=(4,4,48),
            rank_window=(4,4,48),
            head_factors=(2,1,12),
            device=self.device
        )
        self.tcl_3 = TCL_CHANGED(
            input_size=(16, 24, 24, 4,4,48),
            rank=(4,4,48),
            ignore_modes=(0, 1, 2),
            bias=bias,
            device=self.device
        )

        self.block3_list = nn.ModuleList([
            SwinBlock3(
                w_msa=self.w_msa_3,
                sw_msa=self.sw_msa_3,
                tcl=self.tcl_3,
                embed_shape=(4,4,48),
                dropout=dropout,
                drop_path_rate=drop_path_rates[dpr_index + i]
            )
            for i in range(self.depths[2])
        ])
        dpr_index += self.depths[2]

        # Stage 3 -> 4 transition: embedding shape (4,4,48) -> (4,4,96).
        self.patch_merging_3 = TensorizedPatchMerging(
            input_size=(16, 24, 24, 4,4,48),
            in_embed_shape=(4,4,48),
            out_embed_shape=(4,4,96),
            bias=bias,
            ignore_modes=(0, 1, 2),
            device=self.device
        )

        # ------------------ Stage 4 Setup -------------------
        # Shared across both blocks of stage 4.
        self.w_msa_4 = WindowMSA(
            window_size=12,
            embed_dims=(4,4,96),
            rank_window=(4,4,96),
            head_factors=(2,1,24),
            device=self.device
        )
        self.sw_msa_4 = ShiftedWindowMSA(
            window_size=12,
            embed_dims=(4,4,96),
            rank_window=(4,4,96),
            head_factors=(2,1,24),
            device=self.device
        )
        # NOTE(review): the grid here is written as 12 x 7, whereas stages 1-3
        # use square grids (96, 48, 24). If each merge halves the grid, 12 x 12
        # would be expected. Possibly harmless because modes 0-2 are in
        # ignore_modes — verify against TCL_CHANGED before changing.
        self.tcl_4 = TCL_CHANGED(
            input_size=(16, 12, 7, 4,4,96),
            rank=(4,4,96),
            ignore_modes=(0, 1, 2),
            bias=bias,
            device=self.device
        )

        self.block4_list = nn.ModuleList([
            SwinBlock4(
                w_msa=self.w_msa_4,
                sw_msa=self.sw_msa_4,
                tcl=self.tcl_4,
                embed_shape=(4,4,96),
                dropout=dropout,
                drop_path_rate=drop_path_rates[dpr_index + i]
            )
            for i in range(self.depths[3])
        ])
        # Not read again after this point; kept for symmetry with earlier stages.
        dpr_index += self.depths[3]

        # ------------------ Classifier / Final Layer -------------------
        # Maps the pooled (4, 4, 96) embedding to 200 outputs; mode 0 (listed
        # in ignore_modes) is presumably the batch dimension — TODO confirm.
        self.classifier = TRL(
            input_size=(16, 4, 4, 96),
            output=(200,),
            rank=(4,4,96,200),
            ignore_modes=(0,),
            bias=bias,
            device=self.device
        )

        # ------------------ Position Embedding -------------------
        # Learned embedding added to the patch-embedding output in forward().
        # Its shape is a literal and is not derived from img_size, patch_size
        # or embed_shape.
        self.pos_embedding = nn.Parameter(
            torch.randn(
                1, 96, 96, 4, 4, 12, device=self.device
            ),
            requires_grad=True
        )

    def forward(self, x):
        """Run the full network on a batch and return the classifier output.

        Args:
            x: Input batch as expected by ``Patch_Embedding`` (presumably
                images of shape (B, in_chans, img_size, img_size) — TODO confirm).

        Returns:
            Output of the TRL classifier (200 outputs per sample according to
            its ``output=(200,)`` configuration).
        """
        # (1) Patch embedding
        x = self.patch_embedding(x)

        # (2) Add position embedding (leading dim of 1 broadcasts over dim 0)
        x = x + self.pos_embedding

        # ----- Stage 1 -----
        for blk in self.block1_list:
            x = blk(x)
        x = self.patch_merging_1(x)

        # ----- Stage 2 -----
        for blk in self.block2_list:
            x = blk(x)
        x = self.patch_merging_2(x)

        # ----- Stage 3 -----
        for blk in self.block3_list:
            x = blk(x)
        x = self.patch_merging_3(x)

        # ----- Stage 4 -----
        for blk in self.block4_list:
            x = blk(x)

        # Global average pooling: average over dims 1 and 2, assumed to be the
        # spatial grid of a (B, H, W, d1, d2, d3) tensor — TODO confirm. The
        # remaining (B, d1, d2, d3) layout matches the classifier's
        # input_size=(16, 4, 4, 96).
        x = x.mean(dim=(1, 2))

        # Final classifier
        output = self.classifier(x)
        return output

In [6]:

from torchvision.transforms.functional import InterpolationMode

# Seed the RNGs that drive weight init, data shuffling/augmentation (torch)
# and mixup/cutmix sampling (numpy) so a Restart & Run All is repeatable.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Setup the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device is set to : {device}')

# Configs
TEST_ID = 'Test_ID011'
batch_size = 16
n_epoch = 60
image_size = 384

# Per-channel normalisation statistics shared by all three transforms.
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)

model = SwinTransformer(
        img_size=image_size,   # single source of truth for the input resolution
        patch_size=4,
        in_chans=3,
        embed_shape=(4,4,12),
        bias=True,
        dropout=0.0,
        max_drop_path=0.2,     # maximum drop path probability
        device=device
    ).to(device)

# Building blocks reused across the train / val / test pipelines (all stateless).
resize = transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC)
to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
]

tiny_transform_train = transforms.Compose([
    resize,
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(),
    *to_normalized_tensor,
    transforms.RandomErasing(p=0.25),  # applied on the normalised tensor
])

# Validation and test use the same deterministic pipeline.
tiny_transform_val = transforms.Compose([resize, *to_normalized_tensor])
tiny_transform_test = transforms.Compose([resize, *to_normalized_tensor])


train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(
    data_dir='../datasets',
    transform_train=tiny_transform_train,
    transform_val=tiny_transform_val,
    transform_test=tiny_transform_test,
    batch_size=batch_size,
    image_size=image_size,
)

Device is set to : cpu


In [7]:
def mixup_data(x, y, alpha=0.8):
    """Blend every sample with a randomly chosen partner from the same batch.

    x: input batch (dim 0 is the batch dimension)
    y: labels, indexed along dim 0 like x
    alpha: Beta(alpha, alpha) parameter for the mixing weight; with
        alpha <= 0 the weight is fixed to 1.0 (no mixing)

    Returns (mixed_x, y_a, y_b, lam) where mixed_x = lam * x + (1 - lam) * x[perm],
    y_a are the original labels and y_b the partners' labels.
    """
    # Mixing weight for this batch.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0

    # Random pairing of the samples within the batch.
    partner = torch.randperm(x.size(0)).to(x.device)

    blended = lam * x + (1 - lam) * x[partner]
    return blended, y, y[partner], lam

def cutmix_data(x, y, alpha=1.0):
    """Paste a random rectangle from a shuffled copy of the batch into each image (CutMix).

    x: input images, shape (B, C, H, W); NOT modified — a new tensor is returned
    y: labels, indexed along dim 0 like x
    alpha: Beta(alpha, alpha) parameter for the target mix ratio; with
        alpha <= 0 the ratio is fixed to 1.0 (nothing is pasted), matching
        the behaviour of mixup_data

    Returns (mixed_x, y_a, y_b, lam) where y_a are the original labels, y_b the
    labels of the images the patch was taken from, and lam (a Python float) is
    the fraction of each image that still comes from the original, computed
    from the actual (clipped) box area.
    """
    batch_size, _, H, W = x.size()
    indices = torch.randperm(batch_size).to(x.device)
    shuffled_y = y[indices]

    # Target mix ratio; np.random.beta rejects non-positive parameters.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0

    # Box side lengths chosen so that the box covers ~(1 - lam) of the image.
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Randomly choose the center of the box
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    # Box corners, clipped to the image bounds.
    bbx1 = int(np.clip(cx - cut_w // 2, 0, W))
    bby1 = int(np.clip(cy - cut_h // 2, 0, H))
    bbx2 = int(np.clip(cx + cut_w // 2, 0, W))
    bby2 = int(np.clip(cy + cut_h // 2, 0, H))

    # Work on a copy so the caller's tensor is left untouched.
    mixed_x = x.clone()
    mixed_x[:, :, bby1:bby2, bbx1:bbx2] = x[indices, :, bby1:bby2, bbx1:bbx2]

    # Clipping can shrink the box, so recompute the ratio from the real area.
    lam_adjusted = 1.0 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))
    return mixed_x, y, shuffled_y, lam_adjusted

In [None]:
# Report the model size; num_parameters is also written to model_info.txt further down.
num_parameters = count_parameters(model)
print(f'This Model has {num_parameters} parameters')

# Loss and optimizer used by the epoch functions below
# (both left at their library-default hyperparameters).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters())


def train_epoch(loader, epoch):
    """Train the global ``model`` for one pass over ``loader``.

    Each mini-batch is augmented with either mixup (alpha=0.8) or CutMix
    (alpha=1.0), chosen with equal probability, and the loss is the
    lam-weighted sum of the criterion on the two label sets.

    Uses the notebook globals ``model``, ``optimizer``, ``criterion`` and
    ``device``.

    Args:
        loader: Iterable of (inputs, targets) batches.
        epoch: Epoch number, used only in the report text.

    Returns:
        The one-line report string that is also printed.
    """
    model.train()
    start_time = time.time()
    running_loss = 0.0   # sum of per-sample losses
    n_seen = 0           # number of samples processed
    correct = {1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0, 5: 0.0}  # for top1-to-top5 accuracy

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        # Randomly decide which augmentation to use on the mini-batch.
        if np.random.rand() < 0.5:
            inputs_aug, targets_a, targets_b, lam = mixup_data(inputs, targets, alpha=0.8)
        else:
            inputs_aug, targets_a, targets_b, lam = cutmix_data(inputs, targets, alpha=1.0)

        outputs = model(inputs_aug)
        loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)

        loss.backward()
        optimizer.step()

        # criterion averages over the batch, so weight by the batch size to
        # obtain a true per-sample mean at the end of the epoch.
        batch_n = targets.size(0)
        running_loss += loss.item() * batch_n
        n_seen += batch_n

        # Accuracy is scored against the original labels although the inputs
        # were mixed, so it is only an indicative training metric.
        accuracies = topk_accuracy(outputs, targets, topk=(1, 2, 3, 4, 5))
        for k in accuracies:
            correct[k] += accuracies[k]['correct']

    elapsed_time = time.time() - start_time
    denom = max(n_seen, 1)  # guards against an empty loader
    top1_acc, top2_acc, top3_acc, top4_acc, top5_acc = [correct[k] / denom for k in (1, 2, 3, 4, 5)]
    avg_loss = running_loss / denom

    report_train = (f'Train epoch {epoch}: top1={top1_acc*100:.2f}%, top2={top2_acc*100:.2f}%, '
                    f'top3={top3_acc*100:.2f}%, top4={top4_acc*100:.2f}%, top5={top5_acc*100:.2f}%, '
                    f'loss={avg_loss:.4f}, time={elapsed_time:.2f}s')
    print(report_train)
    return report_train
def test_epoch(loader, epoch):
    """Evaluate the global ``model`` on ``loader`` without updating it.

    Uses the notebook globals ``model``, ``criterion`` and ``device``.

    Args:
        loader: Iterable of (inputs, targets) batches.
        epoch: Epoch number, used only in the report text.

    Returns:
        The one-line report string that is also printed.
    """
    model.eval()

    start_time = time.time()
    running_loss = 0.0   # sum of per-sample losses
    n_seen = 0           # number of samples processed
    correct = {1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0, 5: 0.0}  # top1-to-top5 correct counts

    # No gradients are needed for evaluation; this avoids building the
    # autograd graph and the memory that goes with it.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # criterion averages over the batch; weight by batch size so the
            # epoch figure is a per-sample mean.
            batch_n = targets.size(0)
            running_loss += loss.item() * batch_n
            n_seen += batch_n

            accuracies = topk_accuracy(outputs, targets, topk=(1, 2, 3, 4, 5))
            for k in accuracies:
                correct[k] += accuracies[k]['correct']

    elapsed_time = time.time() - start_time
    denom = max(n_seen, 1)  # guards against an empty loader
    top1_acc, top2_acc, top3_acc, top4_acc, top5_acc = [correct[k] / denom for k in (1, 2, 3, 4, 5)]
    avg_loss = running_loss / denom

    # Accuracies are fractions; scale to percent so the '%' sign is truthful
    # and the format matches the training report.
    report_test = (f'Test epoch {epoch}: top1={top1_acc*100:.2f}%, top2={top2_acc*100:.2f}%, '
                   f'top3={top3_acc*100:.2f}%, top4={top4_acc*100:.2f}%, top5={top5_acc*100:.2f}%, '
                   f'loss={avg_loss:.4f}, time={elapsed_time:.2f}s')
    print(report_test)

    return report_test

# Set up the directories to save the results
result_dir = os.path.join('../results', TEST_ID)
result_subdir = os.path.join(result_dir, 'accuracy_stats')
model_subdir = os.path.join(result_dir, 'model_stats')

os.makedirs(result_subdir, exist_ok=True)
os.makedirs(model_subdir, exist_ok=True)

# Appended on every run; the trailing newline keeps entries from successive
# runs from running together on one line.
with open(os.path.join(model_subdir, 'model_info.txt'), 'a') as f:
    f.write(f'total number of parameters:\n{num_parameters}\n')

# Train from Scratch - Just Train
# (evaluation with test_epoch is currently not part of the loop)
print(f'Training for {n_epoch} epochs\n')
for epoch in range(1, n_epoch + 1):
    report_train = train_epoch(train_loader, epoch)
    report = report_train + '\n'

    # Checkpoint the weights every 5 epochs.
    if epoch % 5 == 0:
        model_path = os.path.join(model_subdir, f'Model_epoch_{epoch}.pth')
        torch.save(model.state_dict(), model_path)

    # Append this epoch's report so progress survives an interrupted run.
    with open(os.path.join(result_subdir, 'report_train.txt'), 'a') as f:
        f.write(report)
            

This Model has 2424760 parameters
Training for 60 epochs



KeyboardInterrupt: 

: 