# Tests

## Block tests

In [1]:
from residual_blocks import ClassicBottleneck, InvertedBottleneck, ConvNeXtBlock
import torch
from architecture import Stem, Body, Head

### Classic bottleneck block

In [2]:
# Build a classic bottleneck block from positional arguments (10, 5, 3).
# NOTE(review): presumably (exterior channels, bottleneck channels, kernel size) —
# confirm against residual_blocks.ClassicBottleneck.
classic_bottleneck_block = ClassicBottleneck(10,5,3)



In [3]:
# Random input laid out as (batch, exterior channels, x, y); the block is
# expected to hand back a tensor of exactly the same shape.
test_input = torch.rand(3, 10, 50, 50)
classic_bottleneck_block(test_input).shape

torch.Size([3, 10, 50, 50])

### Inverted bottleneck block

In [4]:
# Inverted bottleneck block built from positional arguments (10, 10), other settings left
# at their defaults. The shape check below shows 10 channels in and 10 channels out.
# NOTE(review): presumably (in_channels, out_channels) — confirm against residual_blocks.
inverted_bottleneck_block = InvertedBottleneck(10,10)

In [5]:
# Fresh random batch in the same (batch, channels, x, y) layout; the output
# shape should match the input shape.
test_input = torch.rand(3, 10, 50, 50)
inverted_bottleneck_block(test_input).shape

torch.Size([3, 10, 50, 50])

In [6]:
# The built-in skip connection should be equivalent to running the block with
# use_residual switched off and adding the input back by hand.
# NOTE(review): assumes the block is deterministic in its current mode (no dropout /
# stochastic depth); otherwise the two forward passes can legitimately differ.
resid_output = inverted_bottleneck_block(test_input)
use_residual_before = inverted_bottleneck_block.use_residual
inverted_bottleneck_block.use_residual = False
try:
    resid_output_2 = inverted_bottleneck_block(test_input) + test_input
finally:
    # Restore the flag so later cells see the block exactly as it was configured.
    inverted_bottleneck_block.use_residual = use_residual_before

# One vectorised comparison instead of a Python-level loop over every element;
# allclose also tolerates round-off from a different summation order.
torch.allclose(resid_output, resid_output_2)

True

In [7]:
# Second inverted bottleneck: widens to 20 channels and downsamples with stride 4.
# The skip connection is off, presumably because the output shape no longer
# matches the input and a residual add would not be possible.
inverted_bottleneck_block_2 = InvertedBottleneck(
    10,
    20,
    kernel_size=7,
    stride=4,
    use_residual=False,
)

In [8]:
# Stride 4 shrinks the 50x50 spatial grid; the channel count grows to 20.
inverted_bottleneck_block_2(test_input).shape

torch.Size([3, 20, 13, 13])

### ConvNeXt

In [9]:
# ConvNeXt block built from positional arguments (10, 10). The shape check below shows
# a 10-channel input mapped to a 10-channel output of the same spatial size.
# NOTE(review): presumably (in_channels, out_channels) — confirm against residual_blocks.
convnext_block = ConvNeXtBlock(10,10)

In [10]:
# Same-shape check for the ConvNeXt block on the shared random test batch.
convnext_block(test_input).shape

torch.Size([3, 10, 50, 50])

## Parameter Counts

In [11]:
def count_params(model, *, trainable_only=False):
    """Return the number of scalar parameters in *model*.

    Args:
        model: a ``torch.nn.Module``. ``model.parameters()`` yields each
            registered parameter once, so shared/tied weights are counted once.
        trainable_only: when True, count only parameters with
            ``requires_grad`` set. The default (False) counts every parameter,
            matching the original behaviour.

    Returns:
        The total element count as an ``int`` (0 for a module without parameters).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

In [12]:
# Assemble stem -> body -> head, push one dummy batch through to confirm the
# stages connect, then check the combined parameter count against the budget.
PARAM_BUDGET = 21_000_000

fake_input = torch.randn((2, 3, 224, 224))  # batch_size 2, 3 input channels
stem = Stem(out_channels=80, stem_type='patchify')
body = Body()
# This is only a shape check, so skip building an autograd graph.
with torch.no_grad():
    out = body(stem(fake_input))
head = Head(out.size(1))  # head input width follows the body's output channels
with torch.no_grad():
    head(out)  # just to check it runs

total_params = sum(count_params(m) for m in (stem, body, head))
target_achieved = 'achieved' if total_params < PARAM_BUDGET else 'exceeded'
budget_m = PARAM_BUDGET // 1_000_000
print(f'total parameters = {total_params}. Target of < {budget_m}m {target_achieved}')

total parameters = 19157641. Target of < 21m achieved


## Data

In [13]:
from data import DefaultLoader, calculate_mean_std

In [14]:
# Training DataLoader requested with 64 samples per batch.
# NOTE(review): the final batch inspected in the cells below has only 2 samples, so the
# dataset size is presumably not a multiple of 64 and partial batches are kept — confirm.
train_dataloader = DefaultLoader.load_train(batch_size=64)

In [15]:
# Walk the entire loader; afterwards img/label are bound to its final
# (possibly partial) batch, which the following cells inspect.
for batch in train_dataloader:
    img, label = batch

In [16]:
# Shape of the last batch pulled from the loader.
img.shape

torch.Size([2, 3, 224, 224])

In [17]:
# Reducing over dim 2 removes that axis from the shape.
img.mean(dim=2).shape

torch.Size([2, 3, 224])

In [18]:
# Collapse the two trailing (spatial) dims into one: (N, C, H, W) -> (N, C, H*W).
# The batch size comes from the tensor itself instead of a hard-coded 2, so this works
# for full 64-sample batches too; flatten() also handles non-contiguous tensors,
# where view() would raise.
img.flatten(start_dim=2).size()

torch.Size([2, 3, 50176])

In [19]:
# Average over the batch axis and both trailing axes, leaving one mean per channel.
img.mean(dim=(0, 2, 3))

tensor([0.6353, 0.5434, 0.5690])

In [20]:
# Per-channel sum of squared values: the E[x^2] ingredient of Var = E[x^2] - E[x]^2.
img.pow(2).sum(dim=(0, 2, 3))

tensor([49324.6758, 35516.7891, 38423.3477])

In [21]:
# Per-channel (mean, std) tensors from data.calculate_mean_std — presumably meant to be
# statistics over the whole training set.
# NOTE(review): the returned mean equals img.mean([0,2,3]) of the single final 2-sample
# batch (In [19]) to every printed digit. The std does not: cell 20's sum of squares gives
# roughly 0.2965 for channel 0 on that batch, versus 0.2912 here. Verify that the mean
# really accumulates over every batch rather than keeping only the last one.
calculate_mean_std(train_dataloader)

(tensor([0.6353, 0.5434, 0.5690]), tensor([0.2912, 0.2392, 0.2291]))