# PyTorch Tutorial

PyTorch is an open-source deep learning framework designed to simplify building neural networks and other machine learning models. Its computation graph is dynamic. The graph is recorded as your Python code runs, so you can use ordinary control flow (loops, `if` statements) and change the network's behavior from one forward pass to the next. This makes it a good choice for both beginners and researchers.

In [1]:
! pip install torch torchvision

Collecting torch
  Downloading torch-2.7.1-cp311-cp311-win_amd64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.22.1-cp311-cp311-win_amd64.whl.metadata (6.1 kB)
Collecting filelock (from torch)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.5-py3-none-any.whl.metadata (6.3 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading torch-2.7.1-cp311-cp311-win_amd64.whl (216.1 MB)


In [None]:
## To install with GPU (CUDA 11.8) support
## ! pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
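## To install for CPU (the default PyPI wheels are CPU-only on Windows; on Linux, add --index-url https://download.pytorch.org/whl/cpu)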

In [2]:
! pip3 install torch torchvision torchaudio

Collecting torchaudio
  Downloading torchaudio-2.7.1-cp311-cp311-win_amd64.whl.metadata (6.6 kB)
Downloading torchaudio-2.7.1-cp311-cp311-win_amd64.whl (2.5 MB)
Installing collected packages: torchaudio
Successfully installed torchaudio-2.7.1





## Tensors in PyTorch
A tensor is a multi-dimensional array that is the fundamental data structure used in PyTorch (and many other machine learning frameworks).

In [3]:
import torch

tensor_1d = torch.tensor([1, 2, 3])
print("1D Tensor (Vector):")
print(tensor_1d)
print()

tensor_2d = torch.tensor([[1, 2], [3, 4]])
print("2D Tensor (Matrix):")
print(tensor_2d)
print()

random_tensor = torch.rand(2, 3)
print("Random Tensor (2x3):")
print(random_tensor)
print()

zeros_tensor = torch.zeros(2, 3)
print("Zeros Tensor (2x3):")
print(zeros_tensor)
print()

ones_tensor = torch.ones(2, 3)
print("Ones Tensor (2x3):")
print(ones_tensor)

1D Tensor (Vector):
tensor([1, 2, 3])

2D Tensor (Matrix):
tensor([[1, 2],
        [3, 4]])

Random Tensor (2x3):
tensor([[0.0185, 0.3767, 0.5523],
        [0.0106, 0.0924, 0.4034]])

Zeros Tensor (2x3):
tensor([[0., 0., 0.],
        [0., 0., 0.]])

Ones Tensor (2x3):
tensor([[1., 1., 1.],
        [1., 1., 1.]])
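Besides its values, every tensor carries metadata that the printouts above only hint at. The short sketch below shows the attributes you will check most often. The variable names are illustrative. The dtypes shown are PyTorch's defaults for Python ints and floats.

```python
import torch

t_int = torch.tensor([1, 2, 3])          # Python ints   -> torch.int64 by default
t_float = torch.tensor([1.0, 2.0, 3.0])  # Python floats -> torch.float32 by default
print(t_int.dtype, t_float.dtype)        # torch.int64 torch.float32

m = torch.zeros(2, 3)
print(m.shape, m.ndim, m.numel(), m.device)  # torch.Size([2, 3]) 2 6 cpu

# Convert the dtype explicitly when an operation needs floating point values
t_as_float = t_int.to(torch.float32)
print(t_as_float)
```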


## Tensor Operations in PyTorch
PyTorch operations are essential for manipulating data efficiently, especially when preparing data for machine learning tasks.

Indexing: Indexing lets you retrieve specific elements or smaller sections from a larger tensor.\
Slicing: Slicing allows you to take out a portion of the tensor by specifying a range of rows or columns.\
Reshaping: Reshaping changes the shape (dimensions) of a tensor without changing its data or its total number of elements. For example, a 3x2 tensor (6 elements) can be reshaped to 2x3 or to 6, but not to 4x2.

In [4]:
import torch

tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])

element = tensor[1, 0]
print(f"Indexed Element (Row 1, Column 0): {element}")  
slice_tensor = tensor[:2, :]
print(f"Sliced Tensor (First two rows): \n{slice_tensor}")

reshaped_tensor = tensor.view(2, 3)
print(f"Reshaped Tensor (2x3): \n{reshaped_tensor}")

Indexed Element (Row 1, Column 0): 3
Sliced Tensor (First two rows): 
tensor([[1, 2],
        [3, 4]])
Reshaped Tensor (2x3): 
tensor([[1, 2, 3],
        [4, 5, 6]])
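In the cell above, `view` and `reshape` look interchangeable, but they differ in one way that matters in practice. `view` never copies. It shares memory with the original tensor and only works when the memory layout allows it. `reshape` falls back to copying when it has to. A small sketch (variable names are illustrative):

```python
import torch

t = torch.tensor([[1, 2], [3, 4], [5, 6]])

v = t.view(2, 3)        # a view: same storage, new shape
v[0, 0] = 100           # writing through the view...
print(t[0, 0].item())   # ...changes the original too: 100

tt = t.T                # transpose: shape (2, 3), but not contiguous in memory
print(tt.is_contiguous())  # False
try:
    tt.view(6)
except RuntimeError:
    print("view() failed on a non-contiguous tensor")

flat = tt.reshape(6)    # reshape() copies the data when a view is impossible
print(flat.tolist())    # [100, 3, 5, 2, 4, 6]

print(t.view(-1, 2).shape)  # -1 lets PyTorch infer that dimension: torch.Size([3, 2])
```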


## Common Tensor Functions: Broadcasting, Matrix Multiplication, etc.
PyTorch offers a variety of common tensor functions that simplify complex operations.

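Broadcasting automatically expands size-1 (or missing leading) dimensions so that tensors of different shapes can be combined element-wise, without physically copying the smaller tensor.\
Matrix multiplication (`torch.matmul` or the `@` operator) is the core computation inside linear layers and attention, so it needs to be fast.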

In [5]:
import torch

tensor_a = torch.tensor([[1, 2, 3], [4, 5, 6]])

tensor_b = torch.tensor([[10, 20, 30]])  # shape (1, 3): its single row is broadcast to both rows of tensor_a

broadcasted_result = tensor_a + tensor_b  # (2, 3) + (1, 3) -> (2, 3)
print(f"Broadcasted Addition Result: \n{broadcasted_result}")

matrix_multiplication_result = torch.matmul(tensor_a, tensor_a.T)  # (2, 3) @ (3, 2) -> (2, 2)
print(f"Matrix Multiplication Result (tensor_a @ tensor_a.T): \n{matrix_multiplication_result}")

Broadcasted Addition Result: 
tensor([[11, 22, 33],
        [14, 25, 36]])
Matrix Multiplication Result (tensor_a @ tensor_a.T): 
tensor([[14, 32],
        [32, 77]])
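The broadcasting rule itself is short. Compare the two shapes starting from the last dimension and working backwards. Two sizes are compatible if they are equal or one of them is 1, and a missing leading dimension counts as 1. PyTorch exposes this as `torch.broadcast_shapes`. The pure-Python `broadcast_shape` below is a hypothetical re-implementation, written only to make the rule explicit:

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast result shape of two shapes, or raise ValueError."""
    result = []
    # Walk both shapes from the last (rightmost) dimension backwards
    for i in range(1, max(len(shape_a), len(shape_b)) + 1):
        a = shape_a[-i] if i <= len(shape_a) else 1  # missing dims count as size 1
        b = shape_b[-i] if i <= len(shape_b) else 1
        if a == b or b == 1:
            result.append(a)
        elif a == 1:
            result.append(b)
        else:
            raise ValueError(f"shapes {shape_a} and {shape_b} are not broadcastable")
    return tuple(reversed(result))

print(broadcast_shape((2, 3), (1, 3)))     # (2, 3): the tensor_a + tensor_b case above
print(broadcast_shape((4, 1, 5), (3, 1)))  # (4, 3, 5)
```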


## GPU Acceleration with PyTorch
PyTorch facilitates GPU acceleration, enabling much faster computations, which is especially important in deep learning due to the extensive matrix operations involved. By transferring tensors to the GPU, you can significantly reduce training times and improve performance.

In [6]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

tensor_size = (10000, 10000)  
a = torch.randn(tensor_size, device=device)  
b = torch.randn(tensor_size, device=device)  

c = a + b  # element-wise addition runs on the device that holds a and b

print("Result shape:", c.shape)  # shape is metadata; no need to copy 400 MB back to the CPU

if device.type == 'cuda':
    print("Current GPU memory usage:")
    print(f"Allocated: {torch.cuda.memory_allocated(device) / (1024 ** 2):.2f} MB")
    print(f"Reserved (cached): {torch.cuda.memory_reserved(device) / (1024 ** 2):.2f} MB")
else:
    print("CUDA is not available, so there is no GPU memory to report.")

Using device: cpu
Result shape: torch.Size([10000, 10000])
CUDA is not available, so there is no GPU memory to report.
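The cell above only shows where the tensors live. To actually measure the speed-up, you have to time the work carefully. CUDA kernels are launched asynchronously, so a timer can stop before the GPU has finished unless you call `torch.cuda.synchronize()`. Below is a minimal sketch. `time_matmul`, the matrix size and the repeat count are illustrative choices, and the numbers you get depend entirely on your hardware.

```python
import time
import torch

def time_matmul(device: torch.device, n: int = 1024, repeats: int = 3) -> float:
    """Average seconds for an n x n matrix multiplication on the given device."""
    a = torch.randn(n, n, device=device)
    b = torch.randn(n, n, device=device)
    torch.matmul(a, b)  # warm-up: the first call can include one-time setup costs
    if device.type == "cuda":
        torch.cuda.synchronize(device)  # wait for queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(repeats):
        torch.matmul(a, b)
    if device.type == "cuda":
        torch.cuda.synchronize(device)  # make sure all timed kernels have finished
    return (time.perf_counter() - start) / repeats

cpu_time = time_matmul(torch.device("cpu"))
print(f"CPU: {cpu_time * 1000:.1f} ms per matmul")
if torch.cuda.is_available():
    gpu_time = time_matmul(torch.device("cuda"))
    print(f"GPU: {gpu_time * 1000:.1f} ms per matmul")
```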


## How PyTorch Is Used in LLMs

PyTorch is a deep learning framework (developed by Meta) used to:

1) Build, train, and deploy neural networks

2) Handle tensors (multi-dimensional arrays) and automatic differentiation

3) Enable GPU acceleration

4) Support dynamic computation graphs (flexible model building)
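The sketch below illustrates items 2 and 4 together. The computation graph is recorded while ordinary Python runs, so a plain `for` loop decides how many multiplications end up in it. Autograd then differentiates whatever was recorded. `power_by_loop` is a hypothetical helper written for illustration.

```python
import torch

def power_by_loop(x: torch.Tensor, n: int) -> torch.Tensor:
    # Python control flow decides the graph's shape at run time: n multiplications
    y = torch.ones_like(x)
    for _ in range(n):
        y = y * x
    return y

x = torch.tensor(2.0, requires_grad=True)
y = power_by_loop(x, 4)   # y = x**4 = 16
y.backward()              # autograd: dy/dx = 4 * x**3 = 32
print(y.item(), x.grad.item())  # 16.0 32.0
```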


| LLM Component               | PyTorch Feature Used                       |
| --------------------------- | ------------------------------------------ |
| Token Embedding Layer       | `nn.Embedding`                             |
| Transformer Block           | `nn.MultiheadAttention` + `nn.Linear`      |
| Feed-Forward Networks (FFN) | `nn.Linear` + activation functions         |
| Positional Encoding         | Custom tensor operations                   |
| Language Modeling Head      | `nn.Linear`                                |
| Loss Calculation            | `F.cross_entropy` or `nn.CrossEntropyLoss` |
| Optimizer                   | `torch.optim.Adam` or `SGD`                |
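
According to the "Positional Encoding" row, this part is built from plain tensor operations. One common choice is the fixed sine/cosine encoding from the original Transformer paper ("Attention Is All You Need"), which takes only a few lines. This is a sketch for illustration. GPT-style models often learn position embeddings with `nn.Embedding` instead, and the function name below is hypothetical.

```python
import math
import torch

def sinusoidal_positional_encoding(seq_len: int, embed_dim: int) -> torch.Tensor:
    """Fixed sin/cos position encodings of shape (seq_len, embed_dim); embed_dim must be even."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)   # (T, 1)
    # 10000 ** (-2i / embed_dim) for i = 0, 1, ..., embed_dim / 2 - 1
    div_term = torch.exp(
        torch.arange(0, embed_dim, 2, dtype=torch.float32) * (-math.log(10000.0) / embed_dim)
    )                                                                      # (E/2,)
    pe = torch.zeros(seq_len, embed_dim)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions use sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions use cosine
    return pe

pe = sinusoidal_positional_encoding(seq_len=8, embed_dim=32)
print(pe.shape)  # torch.Size([8, 32])
# To use it, add it to token embeddings of shape (B, T, E); broadcasting handles the batch dim
```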


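The code below wires these components into a toy character-level model. It shows the core idea behind transformer-based LLMs like GPT: embed the tokens, mix in context with attention, and predict the next token. It leaves out parts a real GPT needs, such as causal masking, positional encoding, layer normalization, residual connections and stacked blocks.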

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

# 1. Sample text (you can replace with any text corpus)
text = "hello world, welcome to mini gpt model with pytorch"

# 2. Create vocabulary
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

# 3. Encoding and decoding functions
def encode(s): return [stoi[c] for c in s]
def decode(l): return ''.join([itos[i] for i in l])

# 4. Prepare dataset
data = torch.tensor(encode(text), dtype=torch.long)
block_size = 8  # context length

def get_batch():
    i = random.randint(0, len(data) - block_size - 1)
    x = data[i:i+block_size]
    y = data[i+1:i+block_size+1]
    return x.unsqueeze(1), y.unsqueeze(1)  # shape: (T, 1)

# 5. Define MiniGPT model
class MiniGPT(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads=2, batch_first=True)
        self.ffn = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        x = self.embed(x)                # (B, T, E)
        x, _ = self.attn(x, x, x)        # (B, T, E)
        logits = self.ffn(x)             # (B, T, vocab_size)
        return logits

# 6. Instantiate model, optimizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MiniGPT(vocab_size, embed_dim=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 7. Training loop
for step in range(500):
    x_batch, y_batch = get_batch()
    x_batch, y_batch = x_batch.transpose(0,1).to(device), y_batch.transpose(0,1).to(device)  # (T, 1) -> (1, T)

    logits = model(x_batch)  # (B, T, vocab)
    loss = F.cross_entropy(logits.view(-1, vocab_size), y_batch.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 50 == 0:
        print(f"Step {step}, Loss: {loss.item():.4f}")

# 8. Text Generation
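def generate(start_char="h", length=50):
    model.eval()
    idx = torch.tensor([[stoi[start_char]]], dtype=torch.long, device=device)
    output = [stoi[start_char]]

    with torch.no_grad():  # no gradients are needed while sampling
        for _ in range(length):
            logits = model(idx)                                # (1, T, vocab_size)
            probs = F.softmax(logits[:, -1, :], dim=-1)        # distribution over the next character
            next_id = torch.multinomial(probs, num_samples=1)  # sample one character id, shape (1, 1)
            output.append(next_id.item())
            idx = torch.cat([idx, next_id], dim=1)             # append it and feed the longer sequence back in

    return decode(output)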

# 9. Generate text
print("\nGenerated text:")
print(generate("w"))


Step 0, Loss: 2.7496
Step 50, Loss: 0.5437
Step 100, Loss: 0.9032
Step 150, Loss: 0.2974
Step 200, Loss: 1.0577
Step 250, Loss: 1.2239
Step 300, Loss: 1.2115
Step 350, Loss: 0.9874
Step 400, Loss: 0.3451
Step 450, Loss: 0.7072

Generated text:
w w we wot we woe we wh wchcith weitwe wel we pyel 


| Concept                 | What It Does                                       |
| ----------------------- | -------------------------------------------------- |
| `nn.Embedding`          | Turns characters into dense vectors                |
| `nn.MultiheadAttention` | Captures relationships between characters          |
| `nn.Linear`             | Predicts next character from the current state     |
| `F.cross_entropy`       | Measures the loss between predicted and true next characters |
| `torch.multinomial`     | Samples the next character from predicted probabilities      |


Let's break down the **MiniGPT code** above, section by section, in **plain terms**.

---

# 🧠 What This Code Does:

We are building a **very tiny version of GPT**, called **MiniGPT**, using only:

* An embedding layer
* A single multi-head attention layer
* A linear (fully connected) layer

This model learns to **predict the next character** in a small piece of text like:

```
Input: "welcom"
Target: "elcome"
```

If trained well, you can use it to **generate new words** character by character.

---

## 🧱 Section-by-Section Explanation

---

### 🔹 1. Dataset & Vocabulary Creation

```python
text = "hello world, welcome to mini gpt model with pytorch"
chars = sorted(list(set(text)))  # Unique characters
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> index
itos = {i: ch for ch, i in stoi.items()}      # index -> char
```

✅ What it does:

* Takes your small text
* Finds all **unique characters** (17 here: 15 lowercase letters, the space, and the comma)
* Maps each character to a number (index)

🧠 Why?
Models don’t understand text directly — they understand **numbers**.
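Running the vocabulary step on its own (no PyTorch needed) produces concrete numbers. This standalone sketch repeats the notebook's definitions so you can inspect them. The text has 17 unique characters. Because `sorted` orders characters by code point, the space and the comma come before the letters.

```python
text = "hello world, welcome to mini gpt model with pytorch"

chars = sorted(set(text))                       # unique characters, in code-point order
stoi = {ch: i for i, ch in enumerate(chars)}    # char -> index
itos = {i: ch for ch, i in stoi.items()}        # index -> char

def encode(s):
    return [stoi[c] for c in s]

def decode(ids):
    return "".join(itos[i] for i in ids)

print(len(chars))                 # 17
print(encode("hello"))            # [6, 4, 8, 8, 11]
print(decode(encode("welcome")))  # welcome
```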

---

### 🔹 2. Encode & Prepare Data

```python
data = torch.tensor(encode(text), dtype=torch.long)
block_size = 8
```

✅ What it does:

* Encodes the entire string as character IDs (with this vocabulary, `h=6`, `e=4`, `l=8`, ...)
* Sets the **context window** to 8 characters: each training example is a run of 8 consecutive characters

### 🔹 3. Get Training Batch

```python
def get_batch():
    i = random.randint(0, len(data) - block_size - 1)
    x = data[i:i+block_size]
    y = data[i+1:i+block_size+1]
    return x.unsqueeze(1), y.unsqueeze(1)
```

✅ What it does:

* Picks a random segment like:
  `x = "welcome "`
  `y = "elcome t"`
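* Adds a dimension, giving shape `(T, 1)`; the training loop then transposes it to `(1, T)`, i.e. `(batch, time)`, the layout that attention with `batch_first=True` expects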
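
Here is a plain-Python sketch of the same slicing, using a fixed start index instead of a random one (`get_window` is a hypothetical helper). Index 13 reproduces the `"welcome "` / `"elcome t"` example above. The sketch also shows why one window contains several training signals: the loss compares the output at each position t with `y[t]`, the character that follows `x[:t+1]`. In a GPT-style model, position t may only look at `x[:t+1]`. Enforcing that requires a causal mask, which this notebook's model does not yet use.

```python
text = "hello world, welcome to mini gpt model with pytorch"
block_size = 8

def get_window(s, i, block_size):
    """Input window and target window starting at position i (the target is shifted by one)."""
    x = s[i:i + block_size]
    y = s[i + 1:i + block_size + 1]
    return x, y

x, y = get_window(text, 13, block_size)
print(repr(x), repr(y))  # 'welcome ' 'elcome t'

# One window holds block_size (context -> next character) pairs
for t in range(block_size):
    print(repr(x[:t + 1]), "->", repr(y[t]))
```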

---

### 🔹 4. Model Definition: `MiniGPT`

```python
self.embed = nn.Embedding(vocab_size, embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads=2, batch_first=True)
self.ffn = nn.Linear(embed_dim, vocab_size)
```

✅ What each layer does:

| Layer                | Purpose                                                                   |
| -------------------- | ------------------------------------------------------------------------- |
| `Embedding`          | Turns each character ID into a **vector** (dense representation)          |
| `MultiheadAttention` | Looks at context: "What other letters are important for this prediction?" |
| `Linear`             | Projects each vector to `vocab_size` scores (logits), one per possible next letter |

---

### 🔹 5. Forward Pass (Prediction)

```python
x = self.embed(x)                # Shape becomes (Batch, Time, Embedding)
x, _ = self.attn(x, x, x)        # Apply attention to the input
logits = self.ffn(x)             # Predict next character
```

🧠 At this stage:

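* Each character becomes a **vector**
* Attention lets each position mix in information from the other characters in the window (in this code, later characters too, because no causal mask is passed)
* The output holds, for every position, a score for each possible next character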
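
In the notebook's `forward`, `self.attn(x, x, x)` is called without a mask. Every position can therefore look at later characters as well, including the exact character it is being trained to predict. GPT-style models prevent this with a causal mask. Below is a minimal sketch using the `attn_mask` argument of `nn.MultiheadAttention`. For a boolean mask, `True` marks positions that may not be attended to. The sizes are arbitrary.

```python
import torch
import torch.nn as nn

T, E = 5, 32
attn = nn.MultiheadAttention(E, num_heads=2, batch_first=True)
x = torch.randn(1, T, E)  # (B, T, E), e.g. the output of the embedding layer

# True above the diagonal = "may not attend": position t only sees positions 0..t
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

out, weights = attn(x, x, x, attn_mask=causal_mask)  # weights: (B, T, T), averaged over heads
print(out.shape)                               # torch.Size([1, 5, 32])
print(weights[0].triu(1).abs().sum().item())   # 0.0: no attention weight on future positions
```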

---

### 🔹 6. Training the Model

```python
for step in range(500):
    logits = model(x_batch)
    loss = F.cross_entropy(logits.view(-1, vocab_size), y_batch.view(-1))
    ...
```

✅ Training Goal:

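* Predict the next character given the current ones
* Minimize the difference between **predicted** and **actual** next characters (the cross-entropy loss)
* Update the weights with Adam after every step; the loss trends down but fluctuates, since each batch is a single random 8-character window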
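
For each position, `F.cross_entropy` computes the negative log of the probability the model assigned to the correct next character, then averages over positions. A pure-Python sketch of the per-position formula follows. `cross_entropy` here is a hypothetical helper, not the PyTorch function. The formula also explains the first logged value: uniform guessing over a 17-character vocabulary gives ln(17) ≈ 2.83, close to the `Step 0` loss of 2.7496.

```python
import math

def cross_entropy(logits, target):
    """Negative log-probability of the target class under softmax(logits)."""
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

print(round(cross_entropy([2.0, 1.0, 0.1], target=0), 4))  # 0.417

# A model that knows nothing spreads probability evenly over the 17 characters
print(round(cross_entropy([0.0] * 17, target=3), 4))       # 2.8332, i.e. ln(17)
```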

---

### 🔹 7. Text Generation

```python
def generate(start_char="h", length=50):
    ...
```

✅ What it does:

* Starts with one character (e.g., "w")
* Predicts next one → adds to input → predicts next → keeps going
* With a well-trained model, the text grows like:

```
w → we → wel → welc → welco → welcom → welcome ...
```

---

## 🧪 Why Is Your Output Repetitive?

```text
w w we wot we woe we wh wchcith weitwe wel we pyel
```

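➡️ **Reasons**:

* Training data is **too small**
* Only trained for **500 steps**
* Model capacity is **tiny** (32-dimensional vectors, a single attention layer)
* **No causal mask**: during training each position can attend to later characters, including the one it is supposed to predict, but during generation those characters don't exist yet
* **No positional information**: without it, attention cannot tell the order of the characters
* The model hasn't learned complex patterns yet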

---

## ✅ What You Can Do to Improve

1. **Train longer**

   * Increase `range(500)` to `range(2000)` or more
2. **Use more text**

   * Feed it a larger dataset (e.g., from a file or public dataset)
3. **Stack more layers**

   * Add another attention + linear layer
4. **Add positional encoding**

   * So it knows where in the sequence it is
5. **Add a causal mask**

   * So each position only attends to earlier characters, as in GPT
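
Items 4 and 5 can be sketched together. The class below is a hypothetical variant of `MiniGPT`, not code from this notebook. It adds a learned position embedding (`nn.Embedding(block_size, embed_dim)`, a GPT-2-style choice rather than fixed sine/cosine encodings) and passes a causal mask to the attention layer. The position table has only `block_size` rows, so the class's own `generate` method crops the context to the last `block_size` characters. The notebook's `generate` function would fail with this model once the text grows past 8 characters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniGPTWithPositions(nn.Module):
    """Hypothetical MiniGPT variant: learned position embeddings plus a causal attention mask."""

    def __init__(self, vocab_size, embed_dim, block_size):
        super().__init__()
        self.block_size = block_size
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos = nn.Embedding(block_size, embed_dim)  # one learned vector per position
        self.attn = nn.MultiheadAttention(embed_dim, num_heads=2, batch_first=True)
        self.ffn = nn.Linear(embed_dim, vocab_size)

    def forward(self, idx):                              # idx: (B, T) with T <= block_size
        T = idx.size(1)
        positions = torch.arange(T, device=idx.device)
        x = self.embed(idx) + self.pos(positions)        # (B, T, E) + (T, E) broadcasts
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=idx.device), diagonal=1)
        x, _ = self.attn(x, x, x, attn_mask=mask, need_weights=False)
        return self.ffn(x)                               # (B, T, vocab_size)

    @torch.no_grad()
    def generate(self, idx, length):
        for _ in range(length):
            idx_cond = idx[:, -self.block_size:]         # the position table only covers block_size steps
            probs = F.softmax(self(idx_cond)[:, -1, :], dim=-1)
            idx = torch.cat([idx, torch.multinomial(probs, num_samples=1)], dim=1)
        return idx

demo_model = MiniGPTWithPositions(vocab_size=17, embed_dim=32, block_size=8)
out = demo_model.generate(torch.zeros(1, 1, dtype=torch.long), length=20)
print(out.shape)  # torch.Size([1, 21])
```

The class keeps the `(B, T)` input and `(B, T, vocab_size)` output of `MiniGPT`. The training loop above should therefore work unchanged if you use an instance of it in place of `model` and build the optimizer from its parameters.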

---

## 🧠 Summary

| Part         | What It Does                         |
| ------------ | ------------------------------------ |
| `Embedding`  | Converts text to vector              |
| `Attention`  | Helps model focus on context         |
| `Linear`     | Predicts next character              |
| `Training`   | Learns by reducing prediction error  |
| `Generate()` | Creates text one character at a time |

---



## How PyTorch Helps in Memory Management of LLMs
Training or running inference on LLMs can require huge memory (RAM/VRAM), especially for:

1) Large vocabularies

2) Long input sequences

3) Deep transformer layers

| Optimization Technique                 | How PyTorch Helps                                                  |
| -------------------------------------- | ------------------------------------------------------------------ |
| ✅ **Mixed Precision (float16)**        | Use `torch.amp` autocast to roughly halve activation memory        |
| ✅ **Gradient Checkpointing**           | Recompute activations to reduce memory usage                       |
| ✅ **Quantization (int8, 4-bit)**       | Reduce model size 4x–8x via `torch.quantization` or `bitsandbytes` |
| ✅ **TorchScript**                      | Compile model for efficient inference                              |
| ✅ **Model Parallelism**                | Spread large LLM across multiple GPUs                              |
| ✅ **Zero Redundancy Optimizer (ZeRO)** | Shard optimizer states across GPUs                                 |
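
The "2x" and "4x–8x" figures in the table follow directly from the number of bytes per parameter. Below is a back-of-the-envelope sketch for a hypothetical 7-billion-parameter model. It counts the weights only; activations, gradients and optimizer state come on top during training.

```python
BYTES_PER_PARAM = {"float32": 4, "float16/bfloat16": 2, "int8": 1, "4-bit": 0.5}

def weight_memory_gb(num_params, precision):
    """Memory for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[precision] / 1e9

for precision in BYTES_PER_PARAM:
    print(f"{precision:>16}: {weight_memory_gb(7e9, precision):5.1f} GB")
```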


PyTorch provides several features and tools to optimize memory usage:

### 1. Mixed Precision Training (float16 / bfloat16)
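Runs selected operations in 16-bit floating point (float16 or bfloat16). This roughly halves the memory those operations need for activations and speeds up matrix math on GPUs that support it, while the model weights and the optimizer step stay in float32: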
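
Why does the mixed-precision cell need a `GradScaler`? float16 has a narrow range, so very small gradients round to zero. bfloat16 keeps float32's exponent range (with fewer mantissa bits), which is why it is usually used without loss scaling. Below is a short sketch using an illustrative gradient value of 1e-8 and an illustrative scale factor.

```python
import torch

print(torch.finfo(torch.float16).max)   # 65504.0: the largest float16 value
print(torch.finfo(torch.float16).tiny)  # the smallest normal float16 value (about 6.1e-05)

tiny_grad = torch.tensor(1e-8)
print(tiny_grad.to(torch.float16).item())   # 0.0: underflows in float16
print(tiny_grad.to(torch.bfloat16).item())  # nonzero: bfloat16 has float32's exponent range

# Loss scaling (the idea behind GradScaler): multiply before going to float16,
# then divide again in float32
scale = 2.0 ** 16
scaled = (tiny_grad * scale).to(torch.float16)
print(scaled.float().item() / scale)        # close to 1e-8 again
```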



In [10]:
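import torch
import torch.nn.functional as F

# Reuses model, optimizer, get_batch, vocab_size and device from the MiniGPT cell above
x_batch, y_batch = get_batch()
x_batch, y_batch = x_batch.transpose(0, 1).to(device), y_batch.transpose(0, 1).to(device)

use_cuda = torch.device(device).type == 'cuda'
# float16 on the GPU needs loss scaling; on the CPU, autocast uses bfloat16 and the scaler is disabled
amp_dtype = torch.float16 if use_cuda else torch.bfloat16
scaler = torch.amp.GradScaler('cuda', enabled=use_cuda)  # replaces the deprecated torch.cuda.amp.GradScaler

with torch.autocast(device_type='cuda' if use_cuda else 'cpu', dtype=amp_dtype):  # enable mixed precision
    output = model(x_batch)
    loss = F.cross_entropy(output.view(-1, vocab_size), y_batch.view(-1))

scaler.scale(loss).backward()  # scale the loss so small float16 gradients don't underflow
scaler.step(optimizer)         # unscale the gradients, then step (skipped if they contain inf/NaN)
scaler.update()                # adjust the scale factor for the next iteration
optimizer.zero_grad()          # reset gradients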


### 2. Gradient Checkpointing
Recomputes some forward activations during the backward pass instead of storing them, trading extra compute for lower memory.


In [13]:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Dummy model block
class MyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.linear(x))

# Wrap it in a model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = MyBlock()

    def forward(self, x):
        # checkpoint the block
        return checkpoint(self.block, x)

# Initialize model and input
model = MyModel()
input_tensor = torch.randn(4, 10, requires_grad=True)  # Batch of 4

# Forward pass using checkpoint
output = model(input_tensor)
print(output)


tensor([[0.0131, 0.0000, 0.0000, 0.0000, 0.0000, 0.2348, 0.0000, 0.0000, 0.7775,
         0.0000],
        [0.0000, 0.4335, 0.0000, 0.0000, 0.0000, 0.4982, 0.6340, 1.8039, 0.0000,
         0.6003],
        [0.2243, 1.0454, 1.7345, 0.5063, 0.0000, 0.0000, 0.0905, 0.0000, 0.4688,
         0.0000],
        [0.0750, 0.9282, 0.0000, 0.0000, 1.0636, 0.0000, 1.1022, 1.5848, 0.0000,
         0.4528]], grad_fn=<CheckpointFunctionBackward>)
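
The cell above only runs a forward pass, where checkpointing has no visible effect. The memory saving and the recomputation happen in `backward()`. The sketch below checks that this trade-off leaves the results unchanged: the gradients match with and without checkpointing. It passes `use_reentrant=False`, the variant the PyTorch documentation recommends; recent versions warn if the argument is omitted. The three-block stack is illustrative.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
blocks = nn.ModuleList([nn.Sequential(nn.Linear(10, 10), nn.ReLU()) for _ in range(3)])
x = torch.randn(4, 10)

def run(use_checkpoint):
    blocks.zero_grad()
    h = x
    for block in blocks:
        # With checkpointing, each block's inner activations are not kept after the
        # forward pass; they are recomputed when backward() reaches that block
        h = checkpoint(block, h, use_reentrant=False) if use_checkpoint else block(h)
    h.sum().backward()
    return [p.grad.clone() for p in blocks.parameters()]

grads_plain = run(use_checkpoint=False)
grads_ckpt = run(use_checkpoint=True)
print(all(torch.allclose(a, b) for a, b in zip(grads_plain, grads_ckpt)))  # True
```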


### 3. Model Quantization (int8 or 4-bit)
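Converts model weights from 32-bit floats to 8-bit integers (4-bit formats need an add-on library such as `bitsandbytes`). Dynamic quantization, shown here, converts the weights of the chosen layer types ahead of time and quantizes activations on the fly during inference. It runs on the CPU:

In [14]:
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic  # older alias: torch.quantization

# Quantize every nn.Linear in `model` (here the MyModel instance from the previous cell) to int8
model_quant = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)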
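
Under the hood, an int8 weight is stored as small integers plus a floating-point scale. The core idea fits in a few lines of plain Python. This sketch uses a simplified symmetric per-tensor scheme for illustration; PyTorch's observers choose scales (and, for some schemes, zero points) in more detail. `quantize_int8` and `dequantize` are hypothetical helpers.

```python
def quantize_int8(weights):
    """Symmetric per-tensor int8 quantization: w ≈ scale * q with q in [-127, 127]."""
    scale = max(abs(w) for w in weights) / 127 or 1.0  # fall back to 1.0 for all-zero input
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale

def dequantize(q, scale):
    return [scale * v for v in q]

w = [0.5, -1.27, 0.003, 1.0]
q, scale = quantize_int8(w)
print(q)                      # [50, -127, 0, 100]
print(dequantize(q, scale))   # close to w; each error is at most scale / 2
```

Each weight now takes 1 byte instead of 4, which is where the 4x size reduction comes from. The small 0.003 weight rounds to 0, showing the precision that is lost.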


| Optimization Technique                 | Description                                                                                  | How PyTorch Helps                                                                                                                                                | Benefit                                             |
| -------------------------------------- | -------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------- |
| ✅ **Mixed Precision (float16 / bfloat16)** | Use half-precision (16-bit) instead of full precision (32-bit) where numerically safe | `torch.amp` (formerly `torch.cuda.amp`) enables automatic mixed precision: `autocast` decides which ops run in 16-bit and which fall back to `float32`, and `GradScaler` prevents float16 gradient underflow | ⚡ Faster computation and 🧠 roughly half the activation memory |
| ✅ **Gradient Checkpointing**           | Save memory by not storing all intermediate activations; recompute them during the backward pass | `torch.utils.checkpoint` lets you wrap parts of your model so their activations are recomputed instead of stored | 🧠 Lower activation memory, at the cost of extra forward computation |
| ✅ **Quantization (int8, 4-bit)**       | Represent model weights with lower precision (e.g., int8 or even 4-bit) | PyTorch supports post-training quantization via `torch.ao.quantization` (formerly `torch.quantization`); 4-bit models are supported through the separate [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library | 🚀 Smaller models, faster inference |
| ✅ **TorchScript**                      | Compile PyTorch models to a static graph | Use `torch.jit.script()` or `torch.jit.trace()` to convert dynamic PyTorch code into TorchScript | 🚀 Faster inference, compatible with C++ deployment |
| ✅ **Model Parallelism**                | Split a large model across multiple GPUs when it can't fit into one | Manually partition layers, or use pipeline parallelism (`torch.distributed.pipelining` in recent releases, replacing the older `torch.distributed.pipeline.sync.Pipe`) | 🧠 Enables training models too large for one GPU (e.g., GPT-3 scale) |
| ✅ **ZeRO (Zero Redundancy Optimizer)** | Shard model parameters, gradients, and optimizer states across GPUs | Microsoft's DeepSpeed library (built on PyTorch) implements the ZeRO stages; PyTorch's own `torch.distributed.fsdp` (Fully Sharded Data Parallel) implements the same idea | 🧠 With full sharding, per-GPU memory for model states shrinks roughly in proportion to the number of GPUs |
| ✅ **Dynamic Graphs and Autograd**      | PyTorch's dynamic computation graph allows fine-grained control of memory and operations | Built-in `autograd` engine for gradient computation; you can `del` tensors you no longer need so their memory can be freed | ⚙️ Custom memory-efficient training logic |


### 1. TorchScript (for Inference Optimization) 
Converts your model to a static, optimized graph using torch.jit.

In [15]:
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)

# Instantiate and trace
model = SimpleModel()
example_input = torch.randn(1, 10)

# Convert to TorchScript
traced_model = torch.jit.trace(model, example_input)
traced_model.save("simple_model.pt")  # Save for production

# Load and run
loaded_model = torch.jit.load("simple_model.pt")
output = loaded_model(torch.randn(1, 10))
print("TorchScript Output:", output)


TorchScript Output: tensor([[-0.1812,  0.4482, -0.1455, -0.2589,  0.0545]],
       grad_fn=<AddmmBackward0>)
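
`torch.jit.trace` records the operations that run for the example input, so data-dependent control flow is frozen into whichever branch that input took. `torch.jit.script` compiles the Python source and keeps the branch. The small sketch below shows the difference (`SignDependent` is a made-up module).

```python
import warnings
import torch
import torch.nn as nn

class SignDependent(nn.Module):
    def forward(self, x):
        # Data-dependent branch: which path runs depends on the input values
        if x.sum() > 0:
            return x * 2
        return -x

module = SignDependent()
positive = torch.tensor([1.0, 2.0])
negative = torch.tensor([-1.0, -2.0])

with warnings.catch_warnings():
    warnings.simplefilter("ignore")              # tracing this module emits a TracerWarning about the branch
    traced = torch.jit.trace(module, positive)   # records only the path taken for `positive`
scripted = torch.jit.script(module)              # compiles the source, keeping the `if`

print(traced(negative))    # tensor([-2., -4.]): still applies x * 2
print(scripted(negative))  # tensor([1., 2.]): takes the correct branch
```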


### 2. Model Parallelism (Manual GPU Splitting)
Split model layers across two GPUs (requires at least 2 GPUs).

In [None]:
import torch
import torch.nn as nn

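# Requires at least 2 GPUs; fail early with a clear message otherwise
assert torch.cuda.device_count() >= 2, "This example needs at least 2 GPUs"
device0 = torch.device("cuda:0")
device1 = torch.device("cuda:1")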

# Split model across devices
class SplitModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq1 = nn.Linear(10, 20).to(device0)
        self.seq2 = nn.Linear(20, 5).to(device1)

    def forward(self, x):
        x = x.to(device0)
        x = torch.relu(self.seq1(x))
        x = x.to(device1)
        return self.seq2(x)

model = SplitModel()
input_tensor = torch.randn(4, 10)
output = model(input_tensor)
print("Model Parallel Output:", output)
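
The cell above fails on a machine with fewer than two GPUs. The variant below uses the same idea but falls back to the CPU for both stages, so the code path can be tested anywhere, including the backward pass across the device boundary. `TwoStageModel` is a hypothetical name, and when both stages share one device the split saves no memory.

```python
import torch
import torch.nn as nn

# Use two GPUs when available; otherwise put both stages on the CPU so the sketch still runs
if torch.cuda.device_count() >= 2:
    dev0, dev1 = torch.device("cuda:0"), torch.device("cuda:1")
else:
    dev0 = dev1 = torch.device("cpu")

class TwoStageModel(nn.Module):
    def __init__(self, dev0, dev1):
        super().__init__()
        self.dev0, self.dev1 = dev0, dev1
        self.stage1 = nn.Linear(10, 20).to(dev0)
        self.stage2 = nn.Linear(20, 5).to(dev1)

    def forward(self, x):
        x = torch.relu(self.stage1(x.to(self.dev0)))
        return self.stage2(x.to(self.dev1))  # activations cross devices here

model_mp = TwoStageModel(dev0, dev1)
out = model_mp(torch.randn(4, 10))
out.sum().backward()  # autograd sends gradients back across the device boundary
print(out.shape, out.device)
```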


In [16]:
### 3. ZeRO (Zero Redundancy Optimizer) via FSDP

In [None]:
import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def fsdp_main(rank, world_size):
    # 1. Initialize process group (the default env:// init reads MASTER_ADDR and MASTER_PORT)
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")  # any free port works
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # 2. Set device
    torch.cuda.set_device(rank)

    # 3. Model & wrap with FSDP
    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(1000, 1000),
                nn.ReLU(),
                nn.Linear(1000, 1000)
            )

        def forward(self, x):
            return self.net(x)

    model = ToyModel().to(rank)
    fsdp_model = FSDP(model)  # Automatically shards params

    # 4. Training step
    optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
    input_tensor = torch.randn(32, 1000).to(rank)
    target = torch.randn(32, 1000).to(rank)

    for step in range(5):
        output = fsdp_model(input_tensor)
        loss = nn.MSELoss()(output, target)
        loss.backward()
        optim.step()
        optim.zero_grad()
        print(f"[Rank {rank}] Step {step} Loss: {loss.item()}")

    dist.destroy_process_group()

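# Launch one process per GPU with torch.multiprocessing.spawn.
# Save this cell as a script (e.g. fsdp_demo.py) and run it with `python fsdp_demo.py`:
# the spawned workers must import fsdp_main, which fails for functions defined in a notebook.
if __name__ == "__main__":
    import torch.multiprocessing as mp

    world_size = torch.cuda.device_count()
    assert world_size >= 1, "FSDP with the NCCL backend needs at least one CUDA GPU"
    mp.spawn(fsdp_main, args=(world_size,), nprocs=world_size)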
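
How much does sharding save? The ZeRO paper (Rajbhandari et al.) counts 16 bytes of model state per parameter for mixed-precision Adam training: 2 for float16 weights, 2 for float16 gradients and 12 for the float32 optimizer state (master weights, momentum, variance). Each ZeRO stage divides one more of these across the GPUs. The back-of-the-envelope sketch below ignores activations and buffers; `zero_memory_gb` is a hypothetical helper.

```python
def zero_memory_gb(num_params, num_gpus, stage):
    """Per-GPU memory (GB) for model states under mixed-precision Adam, following the
    ZeRO paper's accounting of 2 + 2 + 12 bytes per parameter."""
    params, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage >= 1:
        optim /= num_gpus   # stage 1: shard the optimizer states
    if stage >= 2:
        grads /= num_gpus   # stage 2: also shard the gradients
    if stage >= 3:
        params /= num_gpus  # stage 3: also shard the parameters
    return (params + grads + optim) / 1e9

for stage in range(4):
    print(f"stage {stage}: {zero_memory_gb(7.5e9, 64, stage):.2f} GB per GPU")
```

With 7.5 billion parameters on 64 GPUs, this gives 120 GB per GPU without sharding, and about 31.4, 16.6 and 1.9 GB for stages 1 to 3. These match the example used in the ZeRO paper. FSDP's default `FULL_SHARD` strategy corresponds to stage 3.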
