In [1]:
import torch
import sys
# Report interpreter / framework versions, then CUDA + GPU details when a device is usable.
print(sys.version)
print(f"PyTorch version: {torch.__version__}")

cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"cuDNN version: {torch.backends.cudnn.version()}")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")


3.9.19 | packaged by conda-forge | (main, Mar 20 2024, 12:38:46) [MSC v.1929 64 bit (AMD64)]
PyTorch version: 2.3.1+cu118
CUDA available: True
CUDA version: 11.8
cuDNN version: 8700
GPU Name: NVIDIA GeForce RTX 4060 Laptop GPU


In [2]:
import torch


# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("GPU is available!")
else:
    print("GPU is not available, using CPU.")

# Tiny smoke test: allocate two vectors directly on `device` and add them there.
x = torch.tensor([1.0, 2.0, 3.0], device=device)
y = torch.tensor([4.0, 5.0, 6.0], device=device)
z = torch.add(x, y)

print(f"x: {x}")
print(f"y: {y}")
print(f"z: {z} (on {device})")

if device.type == "cuda":
    print("Current GPU:", torch.cuda.get_device_name(0))


GPU is available!
x: tensor([1., 2., 3.], device='cuda:0')
y: tensor([4., 5., 6.], device='cuda:0')
z: tensor([5., 7., 9.], device='cuda:0') (on cuda)
Current GPU: NVIDIA GeForce RTX 4060 Laptop GPU


In [3]:
import torch
import torch.nn as nn


batch_size = 4
seq_len = 10
embed_dim = 64  # must be divisible by num_heads
num_heads = 8   # each head works on embed_dim // num_heads = 8 features

# Multi-head attention layer; batch_first=True -> tensors are (batch_size, seq_len, embed_dim).
multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

# Random query / key / value, each of shape (batch_size, seq_len, embed_dim).
query = torch.rand(batch_size, seq_len, embed_dim)  # Q
key = torch.rand(batch_size, seq_len, embed_dim)    # K
value = torch.rand(batch_size, seq_len, embed_dim)  # V

# average_attn_weights=True is PyTorch's default; it is spelled out because it decides the
# weight shape: weights are averaged over heads -> (batch_size, seq_len, seq_len).
# Pass average_attn_weights=False to get per-head weights of shape
# (batch_size, num_heads, seq_len, seq_len).
output, attn_weights = multihead_attn(query, key, value, average_attn_weights=True)

print("Output shape:", output.shape)                   # (batch_size, seq_len, embed_dim)
print("Attention weights shape:", attn_weights.shape)  # (batch_size, seq_len, seq_len), head-averaged



Output shape: torch.Size([4, 10, 64])
Attention weights shape: torch.Size([4, 10, 10])


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    """Thin wrapper around a single ``nn.MultiheadAttention`` layer used as self-attention.

    Args:
        embed_dim: Size of each input token embedding. ``nn.MultiheadAttention``
            requires it to be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # batch_first=True: inputs and outputs are laid out as (batch, seq_len, embed_dim).
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

    def forward(self, x):
        """Run self-attention over ``x`` (query = key = value = x).

        Prints the input, output and attention-weight shapes as a debugging aid.

        Args:
            x: Tensor of shape (batch, seq_len, embed_dim).

        Returns:
            Attention output with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 3-D.
        """
        # The layer is already batch_first, so no reshape is needed; only the rank is checked.
        # (No local named ``F`` here, so the torch.nn.functional alias is not shadowed.)
        if x.dim() != 3:
            raise ValueError(
                f"expected a 3-D (batch, seq_len, embed_dim) tensor, got shape {tuple(x.shape)}"
            )
        print(f"Input shape: {x.shape}")

        output, attn_weights = self.attention(x, x, x)
        print(f"Output shape: {output.shape}")
        # With the default average_attn_weights=True the weights are averaged over heads,
        # giving (batch, seq_len, seq_len) rather than a per-head tensor.
        print(f"Attention weights shape: {attn_weights.shape}")

        return output


# Example dimensions. The embedding size is deliberately not called ``F``: that name is the
# ``torch.nn.functional`` alias imported at the top of this cell, and rebinding it to an int
# would silently break any later ``F.softmax(...)``-style call.
B, H, D = 32, 20, 64  # Batch size, Sequence length, Embedding dimension
num_heads = 8
x = torch.randn(B, H, D)

model = MyModel(embed_dim=D, num_heads=num_heads)
output = model(x)


After reshape: torch.Size([32, 20, 64])
Output shape: torch.Size([32, 20, 64])
Attention weights shape: torch.Size([32, 20, 20])


In [5]:
import torch
import torch.nn.functional as F

# Example batch of unnormalised scores, shape (B=32, N=5).
preds = torch.randn(32, 5)
print("Before softmax:", preds, sep="\n")

# Turn each row into a probability distribution over its N entries (last dimension).
preds = preds.softmax(dim=-1)
print("After softmax:", preds, sep="\n")


Before softmax:
tensor([[-2.3936, -0.7790, -0.3967, -0.8505,  0.0509],
        [ 0.2480, -0.5122, -0.7802,  1.2352,  1.5804],
        [-0.5679, -1.5157,  0.4523,  0.9482,  0.0849],
        [-0.1517,  1.1859, -0.1958, -0.1723,  0.1180],
        [ 0.5610,  0.9170, -0.6280,  0.8250,  0.5983],
        [-0.3074, -1.7504, -0.1745,  0.4617, -1.0566],
        [ 0.9257,  1.2592, -2.3153, -0.7699, -0.5144],
        [-0.4343, -0.7214, -1.2467,  0.6687, -0.1391],
        [-0.2405, -0.7276,  0.9328, -1.1414, -0.1175],
        [ 1.5703, -0.2622,  0.8426,  0.7560,  0.6770],
        [-0.7653,  1.6279,  0.9796, -0.8113, -0.4846],
        [ 0.4852,  1.1089, -1.0230, -0.9203,  0.8567],
        [-0.4231, -0.4528,  0.7385, -0.1007, -0.6254],
        [-0.8582,  0.0659, -0.3369,  0.1153,  1.1516],
        [ 1.5190,  0.0485,  0.9633,  1.1710,  0.4995],
        [-0.3589, -0.6403, -0.2113,  0.9878,  1.0411],
        [-0.2826, -0.5999, -2.0255,  0.2208,  0.6662],
        [-1.5142, -2.6993, -0.3244,  0.8340, -0.7