In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import timeit
from jaxtyping import Float
from einops import rearrange, einsum, reduce
from typing import Iterable

In [None]:
x = torch.zeros(4,8)
print(x.dtype)
print(x.size())
print(x.numel())
print(x.element_size())

torch.float32
torch.Size([4, 8])
32
4


In [12]:
x=torch.zeros(4,8,dtype=torch.float16)
print(x.element_size())
x=torch.tensor([1e-8], dtype=torch.float16)
print(x)
x=torch.tensor([1e-8], dtype=torch.bfloat16)
print(x)
print(x.device)

2
tensor([0.], dtype=torch.float16)
tensor([1.0012e-08], dtype=torch.bfloat16)
cpu


In [22]:

print(torch.cuda.is_available())
print(torch.mps.is_available())
num_gpus = torch.mps.device_count()
print(num_gpus)
memory_allocated = torch.mps.current_allocated_memory()
print(memory_allocated)
x=torch.zeros(4,8,dtype=torch.float16)
y = x.to('mps')
print(y.device)
z = torch.zeros(32,32, device="mps")
new_memory_allocated = torch.mps.current_allocated_memory()
memory_used = new_memory_allocated - memory_allocated
print(memory_used)

False
True
1
512
mps:0
4096


In [26]:
x = torch.tensor([[0.,1,2,3],[4,5,6,7],[8,9,10,11],[12,13,14,15]])
print(x.stride(0))
print(x.stride(1))

4
1


In [37]:
x = torch.tensor([[1.,2,3],[4,5,6]])
y= x[0]
print(y)
print(x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr())

y = x[:, 1]
print(y)
print(x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr())

y = x.view(3,2)
print(y)
print(x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr())

y= x.transpose(1,0)
print(y)
print(x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr())

x[0][0] = 100
print(y[0][0])

tensor([1., 2., 3.])
True
tensor([2., 5.])
True
tensor([[1., 2.],
        [3., 4.],
        [5., 6.]])
True
tensor([[1., 4.],
        [2., 5.],
        [3., 6.]])
True
tensor(100.)


In [44]:
x = torch.tensor([[1.,2,3],[4,5,6]])
y = x.transpose(1,0)
print(y)
print(y.is_contiguous())
try:
    y.view(2,3)
except RuntimeError as e:
    print(str(e))

y=x.transpose(1,0).contiguous().view(2,3)
print(y)
print(x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr())

tensor([[1., 4.],
        [2., 5.],
        [3., 6.]])
False
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
tensor([[1., 4., 2.],
        [5., 3., 6.]])
False


In [None]:
x = torch.tensor([1,4,9])
print(x.pow(2))
print(x.sqrt())
print(x.rsqrt())
print(x+x)
print(x*2)
print(x/0.5)

tensor([ 1, 16, 81])
tensor([1., 2., 3.])
tensor([1.0000, 0.5000, 0.3333])
tensor([ 2,  8, 18])
tensor([ 2,  8, 18])
tensor([ 2.,  8., 18.])


In [49]:
x = torch.ones(3,3).triu()
print(x)

tensor([[1., 1., 1.],
        [0., 1., 1.],
        [0., 0., 1.]])


In [9]:
x = torch.ones(16,32)
print(x.stride())
y = torch.ones(32,2)
print(y.stride())
z = x @ y
print(z.stride())
print(z.size())
x = torch.ones(4,8,16,32)
y = torch.ones(32,2)
z = x @ y
print(z.size())

(32, 1)
(2, 1)
(2, 1)
torch.Size([16, 2])
torch.Size([4, 8, 16, 2])


In [None]:
# test transpose
x = torch.tensor(
    [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[[9, 10], [11, 12]], [[13, 14], [15, 16]]]]
)
print(x.size())
print(x.stride())
y = x.transpose(1, -1)
print(y.size())
print(y.stride())

torch.Size([2, 2, 2, 2])
(8, 4, 2, 1)
torch.Size([2, 2, 2, 2])
(8, 1, 2, 4)


In [17]:
# einops
x = torch.ones(2, 2, 3)
print(x.stride())
y = torch.ones(2, 2, 3)
print(y.stride())
z = y.transpose(-2, -1)   
print(z.stride())
z = x @ y.transpose(-2, -1)

# einsum
x: Float[torch.Tensor, "batch seq1 hidden"] = torch.ones(2, 3, 4)
y: Float[torch.Tensor, "batch seq2 hidden"] = torch.ones(2, 3, 4)

#z = x @ y.transpose(-2,-1)

z = einsum(x, y, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")
print(z.size())

# reduce
y = x.mean(-1)
print(y.stride())
print(y)

y = reduce(x, "... hidden -> ...", "sum")
print(y.stride())
print(y)

# rearrange
x: Float[torch.Tensor, "batch seq total_hidden"] = torch.ones(2, 3, 8)
w: Float[torch.Tensor, "hidden1 hidden2"] = torch.ones(4, 4)

x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads = 2)
print(x.stride())
print(x)

x = einsum(x, w, "... hidden1, hidden1 hidden2 -> ... hidden2")
print(x.stride())
print(x)

x = rearrange(x, "... heads hidden2 -> ... (heads hidden2)")
print(x.stride())
print(x)


(6, 3, 1)
(6, 3, 1)
(6, 1, 3)
torch.Size([2, 3, 3])
(3, 1)
tensor([[1., 1., 1.],
        [1., 1., 1.]])
(3, 1)
tensor([[4., 4., 4.],
        [4., 4., 4.]])
(24, 8, 4, 1)
tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.]]],


        [[[1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])
(24, 8, 4, 1)
tensor([[[[4., 4., 4., 4.],
          [4., 4., 4., 4.]],

         [[4., 4., 4., 4.],
          [4., 4., 4., 4.]],

         [[4., 4., 4., 4.],
          [4., 4., 4., 4.]]],


        [[[4., 4., 4., 4.],
          [4., 4., 4., 4.]],

         [[4., 4., 4., 4.],
          [4., 4., 4., 4.]],

         [[4., 4., 4., 4.],
          [4., 4., 4., 4.]]]])
(24, 8, 1)
tensor([[[4., 4., 4., 4., 4., 4., 4., 4.],
         [4., 4., 4., 4., 4., 4., 4., 4.],
         [4., 4.,

In [19]:
def time_matmul(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the number of seconds required to perform `a @ b`."""
    # Wait until previous CUDA threads are done
    if torch.mps.is_available():
        torch.mps.synchronize()
    def run():
        # Perform the operation
        a @ b
        # Wait until CUDA threads are done
        if torch.mps.is_available():
            torch.mps.synchronize()
    # Time the operation `num_trials` times
    num_trials = 5
    total_time = timeit.timeit(run, number=num_trials)
    return total_time / num_trials


In [None]:
B = 16384  # Number of points
D = 32768  # Dimension
K = 8192   # Number of outputs
device = torch.device("mps")
x = torch.ones(B, D, device=device)
w = torch.randn(D, K, device=device)

actual_num_flops = 2 * B * D * K

# actual_time = time_matmul(x, w)
# actual_flop_per_sec = actual_num_flops / actual_time 
# print(actual_flop_per_sec)

x.to(torch.bfloat16)
w.to(torch.bfloat16)
actual_time = time_matmul(x, w)
actual_flop_per_sec = actual_num_flops / actual_time 
print(actual_flop_per_sec)


4450460492667.393
4537461470868.2295


In [25]:
x = torch.tensor([1.,2,3])
w = torch.tensor([1.,1,1], requires_grad=True)
pred_y = x @ w
loss = 0.5 * (pred_y -5).pow(2)
loss.backward()
print(w.grad)

tensor([1., 2., 3.])


In [32]:
input_dim = 16384
output_dim = 32

w = nn.Parameter(torch.randn(input_dim, output_dim))
print(isinstance(w, torch.Tensor))
print(type(w.data))
x = nn.Parameter(torch.randn(input_dim))
output = x @ w
print(output)
print(output.size())

w = nn.Parameter(torch.randn(input_dim, output_dim) / np.sqrt(input_dim))
output = x @ w
print(output)
print(output.size())

w = nn.Parameter(nn.init.trunc_normal_(torch.empty(input_dim, output_dim), std=1 / np.sqrt(input_dim), a=-3, b=3))
output = x @ w
print(output)
print(output.size())

True
<class 'torch.Tensor'>
tensor([ -73.4144, -200.6748,  134.5749,  -77.8689,  148.2784,   45.0297,
         168.0362,  -79.4507,   59.3149, -299.4882,  -95.3090,  234.8435,
          35.7946,  100.4075,   25.6909,  208.0293,  -97.5634,  287.4185,
        -135.6534,   25.7079, -146.1654, -183.7879,  326.0525,   -1.9754,
          97.9761, -105.9708,  -26.5679,  105.8718,   -7.9916,  102.5728,
          88.2670,  -10.6167], grad_fn=<SqueezeBackward4>)
torch.Size([32])
tensor([-0.0130, -0.3221, -0.1343, -2.4081,  0.2725, -0.5096,  0.9437,  0.3971,
        -0.8904,  2.7399,  0.7183,  0.1857, -0.0626,  0.6610,  1.0374,  0.2503,
         0.5053, -2.1089,  0.0930, -0.0555, -0.1158,  1.4827,  2.1302,  0.5563,
        -0.7038,  0.7435, -1.1427, -1.4490,  0.4605, -0.0352, -1.4512, -1.7150],
       grad_fn=<SqueezeBackward4>)
torch.Size([32])
tensor([ 1.4710,  1.0047,  0.7078, -2.1375, -0.3954,  1.2550,  0.0476, -1.3834,
         0.3395, -0.3414,  0.3584, -0.2962,  0.3027,  1.8976, -1.5819,  1

In [16]:
def get_device() -> torch.device:
    """Try to use the GPU if possible, otherwise, use CPU."""
    if torch.mps.is_available():
        return torch.device(f"mps")
    else:
        return torch.device("cpu")

In [14]:
class Linear(nn.Module):
    """Simple linear layer."""
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(input_dim, output_dim) / np.sqrt(input_dim))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight

class Cruncher(nn.Module):
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([
            Linear(dim, dim)
            for i in range(num_layers)
        ])
        self.final = Linear(dim, 1)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply linear layers
        B, D = x.size()
        for layer in self.layers:
            x = layer(x)
        # Apply final head
        x = self.final(x)
        assert x.size() == torch.Size([B, 1])
        # Remove the last dimension
        x = x.squeeze(-1)
        assert x.size() == torch.Size([B])
        return x

In [5]:
def get_num_parameters(model: nn.Module) -> int:
    return sum(param.numel() for param in model.parameters())

In [None]:
# custom model
D = 64
num_layers = 2
model = Cruncher(dim=D, num_layers=num_layers).to(get_device())

param_sizes = [
        (name, param.numel())
        for name, param in model.state_dict().items()
    ]

print(param_sizes)

num_parameters = get_num_parameters(model)
print(num_parameters)
print(D*D*2+D)

device = get_device()
model.to(device)

# run
B = 8
x = torch.randn(B, D, device=device)
y = model(x)
print(y.size())

[('layers.0.weight', 4096), ('layers.1.weight', 4096), ('final.weight', 64)]
8256
8256
torch.Size([8])
torch.Size([8])


In [7]:
# random
seed = 0
torch.manual_seed(seed)

import numpy as np
np.random.seed(0)

import random
random.seed(seed)

In [8]:
class SGD(torch.optim.Optimizer):
    def __init__(self, params: Iterable[nn.Parameter], lr: float = 0.01):
        super(SGD, self).__init__(params, dict(lr=lr))
    def step(self):
        for group in self.param_groups:
            lr = group["lr"]
            for p in group["params"]:
                grad = p.grad.data
                p.data -= lr * grad

In [9]:
class AdaGrad(torch.optim.Optimizer):
    def __init__(self, params: Iterable[nn.Parameter], lr: float = 0.01):
        super(AdaGrad, self).__init__(params, dict(lr=lr))
    def step(self):
        for group in self.param_groups:
            lr = group["lr"]
            for p in group["params"]:
                # Optimizer state
                state = self.state[p]
                grad = p.grad.data
                # Get squared gradients g2 = sum_{i<t} g_i^2
                g2 = state.get("g2", torch.zeros_like(grad))
                # Update optimizer state
                g2 += torch.square(grad)
                state["g2"] = g2
                # Update parameters
                p.data -= lr * grad / torch.sqrt(g2 + 1e-5)


In [17]:
# optimizer
B = 2
D = 4
num_layers = 2
model = Cruncher(dim=D, num_layers=num_layers).to(get_device())

optimizer = AdaGrad(model.parameters(), lr=0.01)
state = model.state_dict()  # @inspect state

x = torch.randn(B, D, device=get_device())
y = torch.tensor([4., 5.], device=get_device())
pred_y = model(x)
loss = F.mse_loss(input=pred_y, target=y)
loss.backward()

# Take a step
optimizer.step()
state = model.state_dict()  # @inspect state

# Free up the memory (optional)
optimizer.zero_grad(set_to_none=True)

In [18]:
def get_num_parameters(model: nn.Module) -> int:
    return sum(param.numel() for param in model.parameters())

In [20]:
# Memory
num_parameters = (D * D * num_layers) + D
print(num_parameters)
print(get_num_parameters(model))

num_activations = B * D * num_layers
print(num_activations)

num_gradients = num_parameters

num_optimizer_states = num_parameters 

total_memory = 4 * (num_parameters + num_activations + num_gradients + num_optimizer_states)

print(total_memory)

# compute
flops = 6 * B * num_parameters 
print(flops)

36
36
16
496
432


In [None]:
A = np.array([[[0,1],[2,3],[4,5]],[[6,7],[8,9],[10,11]]])
print(A)
At = A.T
print(At)


[[[0 1]
  [2 3]]

 [[4 5]
  [6 7]]]
[[[0 4]
  [2 6]]

 [[1 5]
  [3 7]]]


In [3]:
x1 = 2.0 / 1000
print(f"{x1:.18f}")

x2 = 1 + (1/10000 ) - (1 - 1/1000)
print(f"{x2:.18f}")

0.002000000000000000
0.001099999999999990
