In [1]:
from typing import (
    Callable, List
)

import matplotlib.pyplot as plt
import numpy as np

In [2]:
class Schedule_Factory:
    def __init__(self, epochs):
        self.epochs = epochs
    
    def constant_schedule(self, learning_rate: float) -> Callable[[int], float]:
        return lambda x: learning_rate
    
    def linear_schedule(self, starting_rate:float, ending_rate:float) -> Callable[[int], float]:
        def lin_sched_fn(epoch:int):
            lrs = np.linspace(starting_rate, ending_rate, self.epochs)
            return lrs[epoch]
        
        return lin_sched_fn
    
    def step_schedule(self, transitions:List, rates:List):
        lrs = np.empty(self.epochs)
        
        lrs[:transitions[0]] = rates[0]
        
        for i, transition in enumerate(transitions[1:]):
            lrs[transitions[i]:transitions[i+1]] = rates[i+1]
        
        print(lrs)
        
        def step_sched_fn(epoch):
            return lrs[epoch]
        
        return step_sched_fn

In [3]:
factory = Schedule_Factory(100)

In [4]:
const_fn = factory.constant_schedule(0.001)
linear_fn = factory.linear_schedule(0.01, 0.001)
step_fn = factory.step_schedule([25,75],[0.01,0.005,0.001])

[1.00000000e-002 1.00000000e-002 1.00000000e-002 1.00000000e-002
 1.00000000e-002 1.00000000e-002 1.00000000e-002 1.00000000e-002
 1.00000000e-002 1.00000000e-002 1.00000000e-002 1.00000000e-002
 1.00000000e-002 1.00000000e-002 1.00000000e-002 1.00000000e-002
 1.00000000e-002 1.00000000e-002 1.00000000e-002 1.00000000e-002
 1.00000000e-002 1.00000000e-002 1.00000000e-002 1.00000000e-002
 1.00000000e-002 5.00000000e-003 5.00000000e-003 5.00000000e-003
 5.00000000e-003 5.00000000e-003 5.00000000e-003 5.00000000e-003
 5.00000000e-003 5.00000000e-003 5.00000000e-003 5.00000000e-003
 5.00000000e-003 5.00000000e-003 5.00000000e-003 5.00000000e-003
 5.00000000e-003 5.00000000e-003 5.00000000e-003 5.00000000e-003
 5.00000000e-003 5.00000000e-003 5.00000000e-003 5.00000000e-003
 5.00000000e-003 5.00000000e-003 5.00000000e-003 5.00000000e-003
 5.00000000e-003 5.00000000e-003 5.00000000e-003 5.00000000e-003
 5.00000000e-003 5.00000000e-003 5.00000000e-003 5.00000000e-003
 5.00000000e-003 5.000000

In [5]:
const_fn(25)

0.001

In [6]:
linear_fn(5)

0.009545454545454546

In [7]:
linear_fn(95)

0.0013636363636363637

In [8]:
step_fn(24)

0.01

In [11]:
import torch

import torchvision.models as models
from torchvision import transforms

In [12]:
from tqdm import tqdm
from PACS_Dataloader.PACS_Dataloader import PACS_Dataset

In [13]:
HOLDOUT_DOMAIN = "photo"
split_name = "train"

In [14]:
pacs_dl = torch.utils.data.DataLoader(
        PACS_Dataset(HOLDOUT_DOMAIN, split_name), batch_size=30, shuffle=True
)

In [15]:
batch_data, batch_labels = next(iter(pacs_dl))

In [25]:
batch_data.dtype

torch.float64

In [16]:
rnet = models.resnet18(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




In [17]:
type(rnet)

torchvision.models.resnet.ResNet

In [18]:
print(rnet)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [29]:
rnet(batch_data.float()).shape

torch.Size([30, 1000])