## Introduction
The Swin Transformer (Shifted Window Transformer) is a vision transformer that divides an image into small, non-overlapping windows and computes self-attention within each window. Unlike standard vision transformers, which apply global attention across all patches, the Swin Transformer uses a "shifted window" technique: the window grid is displaced in alternating layers so that neighboring windows can exchange information. This lets the model capture both local and global features at a much lower cost than global attention.

## Architecture and Working of Swin Transformer

The Swin Transformer's architecture combines a hierarchical design with window-based self-attention for efficient feature extraction.

![Swin Transformer](https://media.geeksforgeeks.org/wp-content/uploads/20250120164009277743/Architecture.webp)

## Swin Transformer block
![STB](https://amaarora.github.io/images/swin-transformer-block.png)

# How Swin Transformer Works

## 1. Patch Splitting
The input image is divided into fixed-size patches, similar to placing a grid over the image where each square represents a patch.  
Each patch is then embedded into a feature vector, which forms the input sequence for the transformer.
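
To make the grid analogy concrete, the sketch below counts the patches produced for a given image and patch size. The helper name `patch_grid` is hypothetical; the 224×224 input and 4×4 patch size are the defaults used by the implementation later in this document.

```python
def patch_grid(img_h, img_w, patch_size, in_channels=3):
    """Return (rows, cols, num_patches, raw_values_per_patch) for a non-overlapping patch grid."""
    if img_h % patch_size or img_w % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    rows, cols = img_h // patch_size, img_w // patch_size
    # Each patch holds patch_size * patch_size pixels with in_channels values each
    return rows, cols, rows * cols, patch_size * patch_size * in_channels


rows, cols, num_patches, raw_dim = patch_grid(224, 224, patch_size=4)
print(rows, cols, num_patches, raw_dim)  # 56 56 3136 48
```

Each of the 3136 patches starts as 48 raw values and is then projected to the embedding dimension, which gives the token sequence the transformer operates on.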

---

## 2. Window-Based Self-Attention
Instead of computing self-attention over the entire image (global attention), the Swin Transformer computes self-attention **within local windows**.

- Each window acts as a small focused region.
- This helps capture fine-grained local features.
- Limiting attention to windows significantly reduces computational cost.

Self-attention is applied independently inside each window.
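
The cost reduction can be estimated with the complexity formulas given in the Swin Transformer paper for global attention and window attention. The sketch below evaluates both; the function names are illustrative, and the chosen sizes (a 56×56 token grid, 96 channels, 7×7 windows) are assumptions that match the first stage of the default configuration used later in this document.

```python
def msa_flops(h, w, C):
    """Global multi-head self-attention over an h x w token grid with C channels."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C


def wmsa_flops(h, w, C, M):
    """Window attention with M x M windows: the quadratic term becomes linear in h * w."""
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C


h = w = 56
C, M = 96, 7
global_cost = msa_flops(h, w, C)
window_cost = wmsa_flops(h, w, C, M)
print(f"global: {global_cost:,}  windowed: {window_cost:,}  ratio: {global_cost / window_cost:.1f}x")
```

The projection term `4hwC²` is identical in both cases; only the attention term changes, from quadratic in the number of tokens to linear.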

---

## 3. Shifted Windows for Cross-Region Interaction
Local window-based attention alone cannot model relationships across different windows.  
The **shifted window mechanism** solves this limitation.

- In the next transformer layer, the window grid is shifted by half the window size.
- Each shifted window therefore straddles the boundaries of several windows from the previous layer.
- As a result, information flows across windows between layers.

This mechanism enables cross-window communication and improves the model’s ability to capture **global context**.
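
A small index calculation shows the effect. The sketch below assumes the shift is implemented as a cyclic roll of the token grid (as the implementation later in this document does with `torch.roll`) and reports which window a token lands in; `window_of` is a hypothetical helper and the 8×8 grid with 4×4 windows is chosen only for illustration.

```python
def window_of(r, c, H, W, M, shift=0):
    """Window (row, col) that token (r, c) falls into after a cyclic shift by `shift`."""
    # Rolling by -shift moves token (r, c) to ((r - shift) % H, (c - shift) % W)
    return ((r - shift) % H) // M, ((c - shift) % W) // M


H = W = 8
M = 4
shift = M // 2

a, b = (3, 3), (4, 4)  # diagonal neighbours on opposite sides of a window boundary
print(window_of(*a, H, W, M), window_of(*b, H, W, M))                # (0, 0) (1, 1)
print(window_of(*a, H, W, M, shift), window_of(*b, H, W, M, shift))  # (0, 0) (0, 0)

# Opposite corners of the image also end up in one window, but only through wrap-around
print(window_of(0, 0, H, W, M, shift), window_of(7, 7, H, W, M, shift))  # (1, 1) (1, 1)
```

The two neighbours are separated in the regular layout and share a window after the shift. The last line shows why shifted attention needs a mask: tokens that meet only because of the wrap-around are not spatial neighbours and should not attend to each other.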

---

## 4. Hierarchical Design
The Swin Transformer processes the image in multiple stages, forming a hierarchical representation.

### Stage 1: Patch Embedding
- The image is divided into non-overlapping patches.
- Each patch is embedded into a feature vector.

### Stage 2: Window-Based Self-Attention
- Patches are grouped into local windows.
- Self-attention is computed independently within each window.

### Stage 3: Shifted Window Self-Attention
- Windows are shifted in the next layer.
- Self-attention is recomputed, enabling interaction across neighboring windows.

### Stage 4: Hierarchical Feature Merging
- Between stages, each 2×2 group of neighboring patches is merged, halving the spatial resolution and doubling the embedding dimension.
- The model captures fine details at lower levels and more abstract, global features at higher levels.
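
The resulting feature pyramid can be tabulated with a few lines of arithmetic. This is a sketch under the assumption of a 224×224 input, 4×4 patches, a base embedding dimension of 96 and four resolution levels, which are the defaults of the implementation later in this document; `stage_shapes` is a hypothetical helper.

```python
def stage_shapes(img_size=224, patch_size=4, base_dim=96, num_stages=4):
    """Return (height, width, channels) of the feature map at each resolution level."""
    shapes = []
    res, dim = img_size // patch_size, base_dim
    for stage in range(num_stages):
        if stage > 0:
            # Patch merging: 2x2 neighbours are fused, so resolution halves and channels double
            res, dim = res // 2, dim * 2
        shapes.append((res, res, dim))
    return shapes


for i, shape in enumerate(stage_shapes(), start=1):
    print(f"level {i}: {shape}")
```

With a 7×7 window, the last level (7×7 tokens) fits into a single window, so no shift is needed there.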



## PyTorch Implementation

## Libraries


In [45]:
import torch
import torch.nn as nn
from timm.layers import DropPath

## Patch Partition

In [46]:
class PatchPartition(nn.Module):
    """
    Splits the input image into non-overlapping patches.

    Args:
        in_channels (int): Number of input channels (e.g., 3 for RGB images).
        emb_dim (int): Embedding dimension for each patch.
        patch_size (int): Size of each patch.
    """
    def __init__(self, in_channels, emb_dim, patch_size):
        super().__init__()
        self.patcher = nn.Conv2d(in_channels=in_channels, out_channels=emb_dim, kernel_size=patch_size, stride=patch_size)  # Similar to nn.Unfold + nn.Linear
        self.flatter = nn.Flatten(start_dim=-2, end_dim=-1)
        self.norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        x = self.patcher(x)
        x = self.flatter(x)
        x = x.permute(0, 2, 1)
        x = self.norm(x)
        return x

## Attention Mask Calculation

In [47]:
def compute_attn_mask(H, W, window_size, shift_size):
    img_mask = torch.zeros((1, H, W, 1))
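    # Negative bounds split each axis into three bands: the untouched part, the band up to
    # the shift boundary, and the band that wraps around after the cyclic shift.
    h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))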
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    B, H, W, C = img_mask.shape
    img_mask = img_mask.view(B, H // window_size, window_size, W // window_size, window_size, C)
    mask_windows = img_mask.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask
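
The masking idea can be checked by hand on a tiny grid. The following pure-Python sketch labels the regions of a shifted feature map and builds one mask per window; it assumes the band boundaries sit at `H - window_size` and `H - shift_size` (negative slice bounds in tensor terms), and `shifted_window_mask` is a hypothetical name, not part of the implementation above.

```python
def shifted_window_mask(H, W, window_size, shift_size, neg=-100.0):
    """Return (region labels, per-window masks) for a cyclically shifted H x W grid."""
    def band(i, n):
        # 0: untouched part, 1: up to the shift boundary, 2: wrapped-around part
        if i < n - window_size:
            return 0
        if i < n - shift_size:
            return 1
        return 2

    labels = [[3 * band(r, H) + band(c, W) for c in range(W)] for r in range(H)]
    masks = []
    for wr in range(0, H, window_size):
        for wc in range(0, W, window_size):
            flat = [labels[r][c]
                    for r in range(wr, wr + window_size)
                    for c in range(wc, wc + window_size)]
            # Tokens may attend to each other only if they come from the same region
            masks.append([[0.0 if a == b else neg for b in flat] for a in flat])
    return labels, masks


labels, masks = shifted_window_mask(H=4, W=4, window_size=2, shift_size=1)
for row in labels:
    print(row)
print(masks[1][0])  # [0.0, -100.0, 0.0, -100.0]
```

In the top-left window every token shares a region, so its mask is all zeros; in the bottom-right window all four tokens come from different regions, so each token attends only to itself.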

## Window Based Self-Attention

In [48]:
class WindowedAttention(nn.Module):
    """
    Performs multi-head self-attention within fixed-size windows.

    Args:
        emb_dim (int): Embedding dimension.
        num_heads (int): Number of attention heads.
        window_size (int): Size of the attention window.
        attn_dropout_p (float): Dropout probability for attention scores.
        output_dropout_p (float): Dropout probability after projection.
        shift_size (int): Size of the window shift.
    """
    def __init__(self, emb_dim, num_heads, window_size, attn_dropout_p, output_dropout_p, shift_size):
        super().__init__()
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.window_size = window_size  # Store window_size for forward pass
        self.scale = (emb_dim // num_heads) ** -0.5
        self.qkv_proj = nn.Linear(in_features=emb_dim, out_features=emb_dim * 3)
        self.attn_dropout = nn.Dropout(p=attn_dropout_p)
        self.softmax = nn.Softmax(dim=-1)
        self.output_projection = nn.Linear(in_features=emb_dim, out_features=emb_dim)
        self.output_dropout = nn.Dropout(p=output_dropout_p)

        # Learnable relative position bias: one (2M-1) x (2M-1) table per head (M = window_size)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(size=(num_heads, (2 * window_size - 1), (2 * window_size - 1))),
            requires_grad=True
        )

        # Sliding M x M patches over the table yield the bias row for each query token
        self.bias_unfold = nn.Unfold(kernel_size=(window_size, window_size), stride=1)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        B, N, C = x.shape
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Expand the (num_heads, 2M-1, 2M-1) table into a (num_heads, N, N) bias.
        # The table patch at sliding position i, read in reverse, holds the biases
        # between query token i and every key token in the window.
        bias_table = self.relative_position_bias_table.unsqueeze(1)
        relative_position_bias = self.bias_unfold(bias_table).flip(dims=(1,)).permute(0, 2, 1)

        # Ensure the bias is on the same device as the attention scores
        relative_position_bias = relative_position_bias.to(attn.device)

        attn += relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            # The mask is a plain tensor attribute, so move it to the same device as attn
            mask = mask.to(attn.device)
            attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_dropout(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.output_projection(x)
        x = self.output_dropout(x)
        return x
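
For comparison, relative position bias is often described as an index lookup: every (query, key) pair inside a window maps to one entry of a `(2M-1) × (2M-1)` table according to the offset between the two tokens. The sketch below builds that index in pure Python. It illustrates the lookup formulation only and is not the `nn.Unfold`-based approach used in the class above; `relative_position_index` is a hypothetical helper.

```python
def relative_position_index(M):
    """Map each (query, key) pair in an M x M window to a flat index into a (2M-1)^2 table."""
    coords = [(r, c) for r in range(M) for c in range(M)]
    index = []
    for ri, ci in coords:
        row = []
        for rj, cj in coords:
            dr = ri - rj + (M - 1)  # row offset shifted into [0, 2M-2]
            dc = ci - cj + (M - 1)  # column offset shifted into [0, 2M-2]
            row.append(dr * (2 * M - 1) + dc)
        index.append(row)
    return index


idx = relative_position_index(2)
for row in idx:
    print(row)
```

For a 7×7 window this means 49 × 49 = 2401 token pairs share only 13 × 13 = 169 learnable values per head, because pairs with the same offset reuse the same entry.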

## Patch Merging

In [49]:
class PatchMerging(nn.Module):
    """
    Merges adjacent patches to reduce the spatial resolution and increase the embedding dimension.

    Args:
        hw (tuple): Height and width of the feature map.
        emb_dim (int): Embedding dimension.
    """
    def __init__(self, hw, emb_dim):
        super().__init__()
        self.H, self.W = hw[0], hw[1]
        self.unfold = nn.Unfold(kernel_size=(2, 2), stride=2)
        self.norm = nn.LayerNorm(4 * emb_dim)
        self.proj = nn.Linear(in_features=4 * emb_dim, out_features=2 * emb_dim)

    def forward(self, x):
        B, N, C = x.shape
        x = x.view(B, self.H, self.W, C).permute(0, 3, 1, 2)
        x = self.unfold(x).permute(0, 2, 1)
        x = self.norm(x)  # Different from PatchPartition, LayerNorm is before projection as in official implementation
        x = self.proj(x)
        return x

In [50]:
class SwinTransformerLayer(nn.Module):
    """
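    Implements two successive Swin Transformer blocks: a regular window
    attention block followed by a shifted-window attention block.

    Note: in this implementation both blocks reuse the same attention module
    (`w_msa`) and the same MLP, so their weights are shared.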

    Args:
        hw (tuple): Height and width of the feature map.
        window_size (int): Size of the attention window.
        num_heads (int): Number of attention heads.
        emb_dim (int): Embedding dimension.
        shift_size (int): Size of window shift for shifted window based self-attention.
        mlp_expansion (int): Expansion ratio in MLP.
        output_dropout_p (float): Dropout probability after attention output.
        mlp_drop_p (float): Dropout probability inside MLP.
        drop_path_p (float): Stochastic depth probability.
    """
    def __init__(self, hw, window_size, num_heads, emb_dim, shift_size, mlp_expansion, output_dropout_p, mlp_drop_p, drop_path_p):
        super().__init__()
        self.shift_size = shift_size
        self.window_size = window_size
        self.H, self.W = hw[0], hw[1]
        if (self.H <= self.window_size) or (self.W <= self.window_size):
            self.shift_size = 0
        self.norm1 = nn.LayerNorm(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.norm3 = nn.LayerNorm(emb_dim)
        self.norm4 = nn.LayerNorm(emb_dim)
        self.w_msa = WindowedAttention(emb_dim=emb_dim, num_heads=num_heads, window_size=window_size, output_dropout_p=output_dropout_p, attn_dropout_p=0, shift_size=shift_size)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=emb_dim, out_features=emb_dim * mlp_expansion),
            nn.GELU(),
            nn.Dropout(p=mlp_drop_p),
            nn.Linear(in_features=emb_dim * mlp_expansion, out_features=emb_dim),
            nn.Dropout(p=mlp_drop_p),
        )
        self.drop_path = DropPath(drop_prob=drop_path_p) if drop_path_p > 0.0 else nn.Identity()
        if self.shift_size > 0:
            self.attn_mask = compute_attn_mask(self.H, self.W, window_size, self.shift_size)

    def forward(self, x):
        B, N, C = x.shape
        residual = x
        x = self.norm1(x)

        x = x.reshape(B, self.H, self.W, C)
        x_windows = x.view(B, self.H // self.window_size, self.window_size, self.W // self.window_size, self.window_size, C)
        x_windows = x_windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.window_size * self.window_size, C)

        x = self.w_msa(x_windows, mask=None)  # First self-attention is regular window based self-attention

        x = x.view(B, self.H // self.window_size, self.W // self.window_size, self.window_size, self.window_size, C).permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, N, C)

        x = residual + self.drop_path(x)
        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = residual + self.drop_path(x)

        residual = x
        x = self.norm3(x)

        x = x.reshape(B, self.H, self.W, C)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))  # cyclic shift
            x_windows = shifted_x.view(B, self.H // self.window_size, self.window_size, self.W // self.window_size, self.window_size, C)
            x_windows = x_windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.window_size * self.window_size, C)
            x = self.w_msa(x_windows, mask=self.attn_mask)  # Shifted window based self-attention
            x = x.view(B, self.H // self.window_size, self.W // self.window_size, self.window_size, self.window_size, C).permute(0, 1, 3, 2, 4, 5)
            shifted_x = x.reshape(B, self.H, self.W, C)
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)).reshape(B, N, C)  # reverse cyclic shift
        else:
            x_windows = x.view(B, self.H // self.window_size, self.window_size, self.W // self.window_size, self.window_size, C)
            x_windows = x_windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.window_size * self.window_size, C)
            x = self.w_msa(x_windows, mask=None)
            x = x.view(B, self.H // self.window_size, self.W // self.window_size, self.window_size, self.window_size, C).permute(0, 1, 3, 2, 4, 5)
            x = x.reshape(B, N, C)
        x = residual + self.drop_path(x)
        residual = x
        x = self.norm4(x)
        x = self.mlp(x)
        x = residual + self.drop_path(x)
        return x

## Swin Transformer Block

In [51]:
class SwinTransformerBlock(nn.Module):
    """
    Composes multiple SwinTransformerLayer blocks without PatchMerging.

    Args:
        hw (tuple): Height and width of the feature map.
        window_size (int): Size of attention window.
        num_heads (int): Number of attention heads.
        emb_dim (int): Embedding dimension.
        shift_size (int): Size of the window shift.
        n_layers (int): Number of Swin Transformer blocks in this stage (should be even);
            n_layers // 2 SwinTransformerLayer modules are created, each holding two blocks.
        output_dropout_p (float): Dropout probability after attention output.
        mlp_drop_p (float): Dropout probability inside MLP.
        mlp_expansion (int): Expansion ratio in MLP.
        drop_path_p (List[float]): List of stochastic depth probabilities per layer.
    """
    def __init__(self, hw, window_size, num_heads, emb_dim, shift_size, n_layers, output_dropout_p, mlp_drop_p, mlp_expansion, drop_path_p):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(n_layers//2):
            layer = SwinTransformerLayer(hw=hw, window_size=window_size, num_heads=num_heads, emb_dim=emb_dim,
                                         shift_size=shift_size, output_dropout_p=output_dropout_p, mlp_drop_p=mlp_drop_p, mlp_expansion=mlp_expansion,
                                         drop_path_p=drop_path_p[i])
            self.layers.append(layer)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

## Swin Transformer (Putting It All Together)

In [52]:
class SwinTransformer(nn.Module):
    """
    The full Swin Transformer model for image classification.

    Args:
        img_size (tuple of int): Size of the input image (height, width).
        in_channels (int): Number of input channels (e.g., 3 for RGB).
        emb_dims (list of int): Embedding dimensions for each stage.
        patch_size (int): Patch size to split the input image into non-overlapping patches.
        depth (int): Number of Swin Transformer stages (e.g., 4).
        num_classes (int): Number of output classes for classification.
        n_layers (list of int): Number of Swin Transformer blocks for each stage (each SwinTransformerLayer holds two blocks).
        window_size (int): Size of the attention window.
        num_heads (list of int): Number of attention heads for each stage.
        output_dropout_p (float): Dropout probability after attention output.
        mlp_drop_p (list of float): Dropout probabilities for the MLP layers in each stage.
        mlp_expansion (int): Expansion ratio in MLP.
        drop_path_p (float): Maximum stochastic depth probability; linearly scaled across layers.
    """
    def __init__(self, img_size=(224, 224), in_channels=3, emb_dims=[96, 192, 384, 768], patch_size=4, depth=4,
                 num_classes=10, n_layers=[2, 2, 6, 2], window_size=7, num_heads=[3, 6, 12, 24],
                 output_dropout_p=0, mlp_drop_p=[0, 0, 0, 0], mlp_expansion=4, drop_path_p=0.1):
        super().__init__()
        img_h, img_w = img_size[0], img_size[1]

        # Stochastic depth rates grow linearly from 0 to drop_path_p across all blocks
        drop_path_rates = torch.linspace(0, drop_path_p, sum(n_layers)).tolist()

        self.layers = nn.ModuleList()
        for i in range(depth):
            if i == 0:
                patch_partition = PatchPartition(in_channels=in_channels, emb_dim=emb_dims[i], patch_size=patch_size)
                swin_block = SwinTransformerBlock(
                    hw=(img_h//(2 ** (i+2)), img_w//(2 ** (i+2))),
                    window_size=window_size,
                    num_heads=num_heads[i],
                    emb_dim=emb_dims[i],
                    shift_size=window_size // 2,
                    n_layers=n_layers[i],
                    output_dropout_p=output_dropout_p,
                    mlp_drop_p=mlp_drop_p[i],
                    mlp_expansion=mlp_expansion,
                    # Slice of stochastic depth rates that belongs to the first stage
                    drop_path_p=drop_path_rates[:n_layers[0]]
                )
                self.layers.append(patch_partition)
                self.layers.append(swin_block)
            else:
                patch_merging = PatchMerging(hw=(img_h//(2 ** (i+1)), img_w//(2 ** (i+1))), emb_dim=emb_dims[i-1])
                swin_block = SwinTransformerBlock(
                    hw=(img_h//(2 ** (i+2)), img_w//(2 ** (i+2))),
                    window_size=window_size,
                    num_heads=num_heads[i],
                    emb_dim=emb_dims[i],
                    shift_size=window_size // 2,
                    n_layers=n_layers[i],
                    output_dropout_p=output_dropout_p,
                    mlp_drop_p=mlp_drop_p[i],
                    mlp_expansion=mlp_expansion,
                    # Slice of stochastic depth rates that belongs to stage i
                    drop_path_p=drop_path_rates[sum(n_layers[:i]):sum(n_layers[:i+1])]
                )
                self.layers.append(patch_merging)
                self.layers.append(swin_block)

        self.norm = nn.LayerNorm(emb_dims[-1])
        self.avg_pool = nn.AdaptiveAvgPool1d(output_size=1)
        self.fc = nn.Linear(in_features=emb_dims[-1], out_features=num_classes) if num_classes > 0 else nn.Identity()
        self.flatten = nn.Flatten(start_dim=-2, end_dim=-1)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.weight, 1.0)
            nn.init.constant_(module.bias, 0.0)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x).transpose(1, 2)
        x = self.avg_pool(x)  # Instead of cls_token, the authors used AdaptiveAvgPool1d for classification
        x = self.flatten(x)
        x = self.fc(x)
        return x
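
The stochastic depth schedule in the constructor can be reproduced without PyTorch to see which rate each stage receives. This is a rough sketch: `drop_path_schedule` is a hypothetical helper that mirrors the `torch.linspace` call and the per-stage slicing above, up to floating-point precision.

```python
def drop_path_schedule(n_layers, max_rate):
    """Linearly spaced drop-path rates from 0 to max_rate, split into one list per stage."""
    total = sum(n_layers)
    step = max_rate / (total - 1) if total > 1 else 0.0
    rates = [i * step for i in range(total)]
    per_stage, start = [], 0
    for n in n_layers:
        per_stage.append(rates[start:start + n])
        start += n
    return per_stage


per_stage = drop_path_schedule([2, 2, 6, 2], max_rate=0.1)
for i, rates in enumerate(per_stage):
    # Each SwinTransformerLayer wraps two blocks but takes one rate,
    # so only the first n // 2 entries of each slice are consumed.
    used = rates[: len(rates) // 2]
    print(f"stage {i}: assigned {len(rates)} rates, used {[round(r, 4) for r in used]}")
```

Deeper stages receive larger rates, so their residual branches are dropped more often during training.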


In [53]:
model = SwinTransformer()  # default configuration: 224x224 RGB input, 10 output classes

## Dataset and Preprocessing

In [54]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler

In [55]:
def get_train_valid_loader(data_dir, batch_size, augment, random_seed, valid_size=0.1, shuffle=True):
    """Get training and validation data loaders for CIFAR-10"""

    # CIFAR-10 normalization
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010]
    )

    # Common transforms
    common_transform = [
        transforms.Resize((224, 224)),  # Resize to the 224x224 input size the model expects by default
        transforms.ToTensor(),
        normalize
    ]

    # Training transforms with optional augmentation
    if augment:
        train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(224, padding=4),
            transforms.ToTensor(),
            normalize
        ])
    else:
        train_transform = transforms.Compose(common_transform)

    valid_transform = transforms.Compose(common_transform)

    # Load datasets
    train_dataset = datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=train_transform
    )
    valid_dataset = datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=valid_transform
    )

    # Create train/valid split
    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler
    )

    return train_loader, valid_loader
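
As a quick sanity check of the split, the sizes can be computed by hand. The sketch assumes the standard CIFAR-10 training set of 50,000 images and a `DataLoader` that keeps the last partial batch (the default); `split_sizes` is a hypothetical helper.

```python
import math


def split_sizes(num_images, valid_size, batch_size):
    """Return (train images, valid images, train batches, valid batches)."""
    num_valid = int(math.floor(valid_size * num_images))
    num_train = num_images - num_valid
    # The last partial batch is kept, hence the ceiling
    return (num_train, num_valid,
            math.ceil(num_train / batch_size), math.ceil(num_valid / batch_size))


print(split_sizes(50_000, valid_size=0.1, batch_size=64))  # (45000, 5000, 704, 79)
```

The 704 training batches agree with the `Step [.../704]` counter in the training log of this document, which was produced with a batch size of 64.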

In [56]:
def get_test_loader(data_dir, batch_size, shuffle=True):
    """Get test data loader for CIFAR-10"""

    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010]
    )

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize
    ])

    dataset = datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform
    )

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle
    )

    return data_loader

In [60]:
# Configuration
data_dir = './data'
num_classes = 10  # CIFAR-10 has 10 classes
num_epochs = 5
batch_size = 64
learning_rate = 0.01

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load data
train_loader, valid_loader = get_train_valid_loader(
    data_dir=data_dir,
    batch_size=batch_size,
    augment=True,  # Enable data augmentation
    random_seed=1
)

test_loader = get_test_loader(data_dir=data_dir, batch_size=batch_size)

model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    weight_decay=0.0001,
    momentum=0.9
)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

Using device: cuda
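
It can help to tabulate what the scheduler does with these settings. The sketch below reproduces the `StepLR` decay rule (multiply the learning rate by `gamma` every `step_size` epochs) in plain Python; `step_lr` is a hypothetical helper, not a PyTorch function.

```python
def step_lr(base_lr, epoch, step_size, gamma):
    """Learning rate in effect during `epoch` (0-based) under a StepLR-style schedule."""
    return base_lr * gamma ** (epoch // step_size)


schedule = [step_lr(0.01, epoch, step_size=7, gamma=0.1) for epoch in range(10)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch + 1}: lr = {lr:.4f}")
```

With `num_epochs = 5` and `step_size = 7`, the first decay would only happen after the seventh epoch, so the learning rate stays at 0.01 for the whole run configured here.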


## Training

In [58]:
# Training loop
total_step = len(train_loader)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Print progress every 100 steps
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')

    # Epoch summary
    avg_loss = running_loss / total_step
    print(f'Epoch [{epoch+1}/{num_epochs}] - Average Loss: {avg_loss:.4f}')

    # Validation
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        print(f'Validation Accuracy: {accuracy:.2f}%\n')

    # Step the scheduler
    scheduler.step()


Epoch [1/5], Step [100/704], Loss: 1.9691
Epoch [1/5], Step [200/704], Loss: 1.9275
Epoch [1/5], Step [300/704], Loss: 1.7681
Epoch [1/5], Step [400/704], Loss: 1.8913
Epoch [1/5], Step [500/704], Loss: 1.9825
Epoch [1/5], Step [600/704], Loss: 1.6147
Epoch [1/5], Step [700/704], Loss: 1.8082
Epoch [1/5] - Average Loss: 1.9038
Validation Accuracy: 37.30%

Epoch [2/5], Step [100/704], Loss: 1.7299
Epoch [2/5], Step [200/704], Loss: 1.4745
Epoch [2/5], Step [300/704], Loss: 1.5769
Epoch [2/5], Step [400/704], Loss: 1.3994
Epoch [2/5], Step [500/704], Loss: 1.7073
Epoch [2/5], Step [600/704], Loss: 1.4582
Epoch [2/5], Step [700/704], Loss: 1.2532
Epoch [2/5] - Average Loss: 1.5753
Validation Accuracy: 44.92%

Epoch [3/5], Step [100/704], Loss: 1.5126
Epoch [3/5], Step [200/704], Loss: 1.3159
Epoch [3/5], Step [300/704], Loss: 1.5085
Epoch [3/5], Step [400/704], Loss: 1.3140
Epoch [3/5], Step [500/704], Loss: 1.3775
Epoch [3/5], Step [600/704], Loss: 1.1526
Epoch [3/5], Step [700/704], Los