In [1]:
import torch

    

In [18]:

from cs336_basics.transformerStuff import softmax

exampleLogits = torch.randn(2, 5)
exampleIndices = torch.tensor([0, 2])

probs = softmax(exampleLogits, dim=-1)
probs

tensor([[0.1587, 0.0488, 0.4874, 0.2383, 0.0668],
        [0.2220, 0.1276, 0.0781, 0.2395, 0.3328]])

In [38]:
import einx
trueLogits = einx.get_at("b [v], (b [idx]) -> b 1", exampleLogits, exampleIndices)

In [32]:
eLogits = torch.exp(exampleLogits - exampleLogits.max(dim=-1, keepdim=True).values)
logSumExp = torch.log(eLogits.sum(axis=-1, keepdim=True))
logSumExp

tensor([[0.7186],
        [1.1003]])

In [44]:
torch.mean(- (trueLogits - logSumExp))

tensor(1.2175)

In [60]:
def cross_entropy(inputs, targets):
    """
        Given a tensor of inputs and targets, compute the average cross-entropy
        loss across examples.
    
        Args:
            inputs (Float[Tensor, "batch_size vocab_size"]): inputs[i][j] is the
                unnormalized logit of jth class for the ith example.
            targets (Int[Tensor, "batch_size"]): Tensor of shape (batch_size,) with the index of the correct class.
                Each value must be between 0 and `num_classes - 1`.
    
        Returns:
            Float[Tensor, ""]: The average cross-entropy loss across examples.
    """
    inputs = einx.rearrange("... v -> (...) v", inputs)
    targets = einx.rearrange("... -> (...)", targets)
    inputs = inputs - inputs.max(dim=-1, keepdim=True).values
    trueLogits = einx.get_at("b [v], (b [idx]) -> b 1", inputs, targets)
    eLogits = torch.exp(inputs)
    logSumExp = torch.log(eLogits.sum(axis=-1, keepdim=True))
    return torch.mean(- (trueLogits - logSumExp))

In [61]:
cross_entropy(exampleLogits, exampleIndices)

tensor(2.1950)

In [62]:
inputs =  torch.tensor([[0.1088, 0.1060, 0.6683, 0.5131, 0.0645],
               [0.4538, 0.6852, 0.2520, 0.3792, 0.2675],
               [0.4578, 0.3357, 0.6384, 0.0481, 0.5612],
               [0.9639, 0.8864, 0.1585, 0.3038, 0.0350],
               [0.3356, 0.9013, 0.7052, 0.8294, 0.8334],
              [0.6333, 0.4434, 0.1428, 0.5739, 0.3810],
              [0.9476, 0.5917, 0.7037, 0.2987, 0.6208],
               [0.8541, 0.1803, 0.2054, 0.4775, 0.8199]])
targets = torch.tensor([1, 0, 2, 2, 4, 1, 4, 0])
cross_entropy(inputs, targets)

tensor(1.6095)

In [65]:
inputs = torch.randn(3, 5, 7)
targets = torch.randint(0, 7, size=(3, 5))
cross_entropy(inputs, targets)

tensor(2.0685)

In [66]:
from collections.abc import Callable, Iterable
from typing import Optional
import torch
import math

In [103]:
class SGD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {"lr": lr}
        print(type(params))
        super().__init__(params, defaults)
    def step(self, closure: Optional[Callable] = None):
        loss = None if closure is None else closure()
        for group in self.param_groups:
            lr = group["lr"] # Get the learning rate.
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p] # Get state associated with p.
                t = state.get("t", 0) # Get iteration number from the state, or initial value.
                grad = p.grad.data # Get the gradient of loss with respect to p.
                p.data -= lr / math.sqrt(t + 1) * grad # Update weight tensor in-place.
                state["t"] = t + 1 # Increment iteration number.
        return loss
weights = torch.nn.Parameter(5 * torch.randn((10, 10)))
opt = SGD([weights], lr=1)
for t in range(10):
    opt.zero_grad() # Reset the gradients for all learnable parameters.
    loss = (weights**2).mean() # Compute a scalar loss value.
    print(loss.cpu().item())
    loss.backward() # Run backward pass, which computes gradients.
    opt.step() # Run optimizer step.

<class 'list'>
26.3189640045166
25.2767333984375
24.566858291625977
24.002788543701172
23.525131225585938
23.106184005737305
22.73040008544922
22.388046264648438
22.072553634643555
21.77923011779785


In [72]:
print(opt.state)

defaultdict(<class 'dict'>, {})


In [74]:
opt.param_groups

[{'params': [Parameter containing:
   tensor([[ -0.0617,  -5.3531,  -0.1866,  -2.1943,  -6.7153,  -4.1213,   4.0880,
              4.8780,   0.1285,   1.5604],
           [ -9.5234,   6.2102,  -4.3096,   3.5231,   2.8335,   3.1488,   3.9583,
              9.0442,  -0.5850,   3.7255],
           [ -0.6307,  -0.3854,  -7.6866,   0.1314,  -1.0857,   2.3122,  -8.5276,
             -2.1388,   5.3930,  -0.6250],
           [  8.4559,   2.1054,   6.6357,  -0.4007,  -1.8352,   1.6471,   4.2291,
              0.7674,  -1.5360,   4.4736],
           [ -3.4156,  -4.5099,   0.9973,   2.2794,   1.7119,   5.8206,   2.9674,
             -2.1727,  -1.0410,  -0.6734],
           [  4.3551,   5.6291,   1.9736,  -4.2587,  -5.4546,  -8.3573,  -8.0197,
             -0.7031,   1.0147,   1.3910],
           [  9.5710,  -9.4592,   0.4313,  -5.2994,  -1.6013,  -5.0713,  -0.9703,
              3.6844,   6.7415,   7.0759],
           [  0.5346,   6.6267,   2.0775,   1.7978,   2.5313,  -5.4249,  -2.2080,
        

In [102]:
class AdamW(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas = (0.9, 0.999), weight_decay = 0.01, eps = 10e-8):
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {"lr": lr, 
                    "beta_1" : betas[0], 
                    "beta_2" : betas[1],
                   "weight_decay" : weight_decay,
                   "epsilon" : eps}
        super().__init__(params, defaults)
    def step(self, closure: Optional[Callable] = None):
        loss = None if closure is None else closure()
        for group in self.param_groups:
            lr = group["lr"] # Get the learning rate.
            beta_1 = group['beta_1']
            beta_2 = group['beta_2']
            weight_decay = group['weight_decay']
            epsilon = group['epsilon']
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p] # Get state associated with p.
                t = state.get("t", 1)
                m = state.get("firstMoment", torch.zeros_like(p.data)) # Get iteration number from the state, or initial value.
                v = state.get("secondMoment", torch.zeros_like(p.data)) # Get iteration number from the state, or initial value.

                grad = p.grad.data # Get the gradient of loss with respect to p.
                m = beta_1 * m + (1 - beta_1) * grad
                v = beta_2 * v + (1 - beta_2) * grad ** 2
                adjustedLR = lr * ((1 - beta_2 ** t) ** 0.5) / (1 - beta_1 ** t)
                p.data -= adjustedLR * m / (v ** 0.5 + epsilon)
                p.data -= lr * weight_decay * p.data

                state['t'] = t + 1
                state['firstMoment'] = m
                state['secondMoment'] = v
                
        return loss
weights = torch.nn.Parameter(5 * torch.randn((10, 10)))
weights2 = torch.nn.Parameter(5 * torch.randn((10, 10)))

opt = AdamW([weights, weights2], lr=1)
for t in range(10):
    opt.zero_grad() # Reset the gradients for all learnable parameters.
    loss = (weights**2).mean() # Compute a scalar loss value.
    print(loss.cpu().item())
    loss.backward() # Run backward pass, which computes gradients.
    opt.step() # Run optimizer step.

26.091079711914062
19.099563598632812
13.836668968200684
10.049921035766602
7.290708065032959
5.216639995574951
3.6774072647094727
2.5980215072631836
1.8952690362930298
1.4653306007385254


In [91]:
def lrScheduler(t, alphaMax, alphaMin, Tw, Tc):
    if t < Tw:
        return t / Tw * alphaMax
    elif t <= Tc:
        cosPortion = math.cos((t - Tw) / (Tc - Tw) * math.pi)
        return alphaMin + 0.5 * (1 + cosPortion) * (alphaMax - alphaMin)
    else:
        return alphaMin

24.787813186645508
defaultdict(<class 'dict'>, {})
