In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Define Ring Attention and Blockwise Transformer classes

class RingAttention(nn.Module):
    """Block-local self-attention over fixed-size chunks of a sequence.

    The sequence is cut into consecutive blocks of ``block_size`` positions and
    scaled dot-product attention is computed independently inside each block;
    positions never attend across block boundaries.

    Args:
        input_size: Feature dimension of every sequence position.
        num_heads: Stored on the module but not read by ``forward``; attention
            is computed as a single head over the full feature dimension.
        block_size: Number of consecutive positions per attention block.

    Raises:
        ValueError: If ``block_size`` is smaller than 1.
    """

    def __init__(self, input_size, num_heads=8, block_size=16):
        super().__init__()
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        self.input_size = input_size
        self.num_heads = num_heads
        self.block_size = block_size

        self.query_projection = nn.Linear(input_size, input_size)
        self.key_projection = nn.Linear(input_size, input_size)
        self.value_projection = nn.Linear(input_size, input_size)
        self.output_projection = nn.Linear(input_size, input_size)

    def forward(self, x):
        """Apply block-local attention to ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, input_size)``. ``seq_len``
                does not have to be a multiple of ``block_size``: the final
                block is zero-padded internally and the padded positions are
                excluded from the softmax.

        Returns:
            Tensor of shape ``(batch_size, seq_len, input_size)``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional or its last dimension
                differs from ``input_size``.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (batch, seq_len, input_size), got {tuple(x.shape)}"
            )
        batch_size, seq_len, input_size = x.size()
        if input_size != self.input_size:
            raise ValueError(
                f"expected last dimension {self.input_size}, got {input_size}"
            )

        # Pad the sequence up to the next multiple of block_size so it can be
        # split into whole blocks; pad_len is always < block_size.
        pad_len = (-seq_len) % self.block_size
        if pad_len:
            x = F.pad(x, (0, 0, 0, pad_len))
        num_blocks = (seq_len + pad_len) // self.block_size

        # [batch_size, num_blocks, block_size, input_size]
        blocks = x.reshape(batch_size, num_blocks, self.block_size, input_size)

        queries = self.query_projection(blocks)
        keys = self.key_projection(blocks)
        values = self.value_projection(blocks)

        # [batch_size, num_blocks, block_size, block_size]
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (input_size ** 0.5)
        if pad_len:
            # Only the last block contains padding. Masking its padded key
            # columns keeps real positions from attending to them; the block
            # still holds at least one real key, so no softmax row is all -inf.
            padded_keys = torch.zeros(
                num_blocks, 1, self.block_size, dtype=torch.bool, device=scores.device
            )
            padded_keys[-1, :, self.block_size - pad_len:] = True
            scores = scores.masked_fill(padded_keys, float("-inf"))
        weights = F.softmax(scores, dim=-1)

        # [batch_size, num_blocks, block_size, input_size]
        attended = torch.matmul(weights, values)

        # Merge the blocks back into one sequence and drop the padded tail.
        attended = attended.reshape(
            batch_size, num_blocks * self.block_size, input_size
        )[:, :seq_len]

        return self.output_projection(attended)

class BlockwiseTransformer(nn.Module):
    """Stack of ``RingAttention`` layers followed by a shared feedforward net.

    Each of the ``num_layers`` steps adds the attention output to its input,
    applies a parameter-free layer normalization over every non-batch
    dimension, and then runs the result through ``self.feedforward``. The same
    feedforward module (and therefore the same weights) is reused at every
    step, and there is no residual connection around it.

    Args:
        input_size: Feature dimension of every sequence position.
        num_layers: Number of ``RingAttention`` layers to stack.
        num_heads: Forwarded to each ``RingAttention`` layer.
        block_size: Forwarded to each ``RingAttention`` layer.
    """

    def __init__(self, input_size, num_layers=4, num_heads=8, block_size=16):
        super().__init__()
        self.input_size = input_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.block_size = block_size

        # One attention module per layer, registered so its parameters are tracked.
        self.attention_layers = nn.ModuleList(
            [
                RingAttention(input_size, num_heads, block_size)
                for _ in range(num_layers)
            ]
        )

        # Position-wise MLP with a 4x hidden expansion, shared by all layers.
        hidden_size = 4 * input_size
        self.feedforward = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
        )

    def forward(self, x):
        """Run ``x`` through every attention layer and the shared feedforward.

        Args:
            x: Input tensor; the attention layers are called on it directly.

        Returns:
            Tensor produced by the final feedforward application.
        """
        hidden = x
        for attention in self.attention_layers:
            # Residual connection around the attention layer.
            residual = hidden + attention(hidden)
            # NOTE(review): normalizes jointly over all dimensions after the
            # batch dimension (not only the feature dimension) -- confirm this
            # is intended.
            normalized = F.layer_norm(residual, normalized_shape=residual.size()[1:])
            hidden = self.feedforward(normalized)
        return hidden

# Define SyntheticDataset class
class SyntheticDataset(Dataset):
    """Map-style dataset of random sequences with a binary label.

    Every access draws a fresh ``(seq_len, input_size)`` tensor from a standard
    normal distribution, so the same index yields different data on each call.
    The label is 1 when the sum of all elements is positive, otherwise 0.

    Args:
        num_samples: Value reported by ``len()``; also the exclusive upper
            bound for valid indices.
        seq_len: Number of positions in each generated sequence.
        input_size: Feature dimension of each position.
    """

    def __init__(self, num_samples, seq_len, input_size):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.input_size = input_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return a random ``(sequence, label)`` pair for position ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-num_samples, num_samples)``.
                Without this, plain iteration over the dataset would never
                stop, because Python's sequence protocol relies on
                ``IndexError`` to detect the end.
        """
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.num_samples}"
            )

        # Generate random sequence tensor
        sequence = torch.randn(self.seq_len, self.input_size)

        # Binary label: 1 if the elements sum to a positive value, else 0
        label = torch.tensor(1 if sequence.sum() > 0 else 0, dtype=torch.long)

        return sequence, label

# Instantiate the dataset and dataloader
num_samples = 1000  # Number of samples in the dataset
seq_len = 1024      # Length of each sequence
input_size = 512    # Dimensionality of each element in the sequence
batch_size = 32     # Batch size for training
num_classes = 2     # SyntheticDataset labels are either 0 or 1
dataset = SyntheticDataset(num_samples, seq_len, input_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Instantiate the model
model = BlockwiseTransformer(input_size=input_size, num_layers=4, num_heads=8, block_size=16)

# The transformer returns input_size features per position, not class scores.
# This head maps the sequence-pooled features to one logit per class so the
# cross-entropy loss is computed over the actual label space.
classifier = nn.Linear(input_size, num_classes)

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
classifier.to(device)

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(classifier.parameters()), lr=0.001
)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(inputs)

        # Mean-pool over the sequence dimension, then classify
        logits = classifier(outputs.mean(dim=1))

        # Compute loss. Targets already have shape (batch,); squeezing them
        # would collapse a size-1 batch to a scalar and break the loss.
        loss = criterion(logits, targets)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if (batch_idx + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(dataloader)}], Loss: {loss.item():.4f}")

    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {total_loss / len(dataloader):.4f}")


Epoch [1/10], Batch [10/32], Loss: 0.6545
Epoch [1/10], Batch [20/32], Loss: 1.1160
Epoch [1/10], Batch [30/32], Loss: 0.6245
Epoch [1/10], Average Loss: 1.1188
Epoch [2/10], Batch [10/32], Loss: 0.7041
Epoch [2/10], Batch [20/32], Loss: 0.7914
Epoch [2/10], Batch [30/32], Loss: 0.7351
Epoch [2/10], Average Loss: 0.7848
Epoch [3/10], Batch [10/32], Loss: 0.7019
Epoch [3/10], Batch [20/32], Loss: 0.6059
Epoch [3/10], Batch [30/32], Loss: 0.7423
Epoch [3/10], Average Loss: 0.7463
Epoch [4/10], Batch [10/32], Loss: 0.7776
Epoch [4/10], Batch [20/32], Loss: 0.7830
Epoch [4/10], Batch [30/32], Loss: 0.6873
Epoch [4/10], Average Loss: 0.7228
Epoch [5/10], Batch [10/32], Loss: 0.7369
Epoch [5/10], Batch [20/32], Loss: 0.6925
Epoch [5/10], Batch [30/32], Loss: 0.7179
Epoch [5/10], Average Loss: 0.7131
Epoch [6/10], Batch [10/32], Loss: 0.7483
Epoch [6/10], Batch [20/32], Loss: 0.7060
Epoch [6/10], Batch [30/32], Loss: 0.6869
Epoch [6/10], Average Loss: 0.7188
Epoch [7/10], Batch [10/32], Loss: