# production

In [137]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import datetime
import copy
import cv2
import statistics
import os.path as path
import os
import glob
import pandas as pd
import random
import shutil
import ctypes  # An included library with Python install.

from spectral import *
from tempfile import mkdtemp
from PIL import Image
from matplotlib import pyplot as plt
from functools import partial
from dataclasses import dataclass
from collections import OrderedDict 

In [127]:
class Conv2dAuto(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.padding =  (self.kernel_size[0] // 2, self.kernel_size[1] // 2) # dynamic add padding based on the kernel_size
        
conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False)  # create an instance of Conv2dAuto with kernel 3 


class ResidualBlock(nn.Module):
    """set the basic format of a block"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels, self.out_channels =  in_channels, out_channels
        # blocks in block
        self.blocks = nn.Identity()
        # in channels != out channels there is a sortcut to the residual that will fit outchannels
        self.shortcut = nn.Identity()   
    
    def forward(self, x):
        residual = x
        if self.should_apply_shortcut: residual = self.shortcut(x)
        x = self.blocks(x)
        x += residual
        return x
    
    @property
    def should_apply_shortcut(self):
        return self.in_channels != self.out_channels

        
class ResNetResidualBlock(ResidualBlock):
    """set the shortcut"""
    def __init__(self, in_channels, out_channels, expansion=1, downsampling=1, conv=conv3x3, *args, **kwargs):
        super().__init__(in_channels, out_channels)
        self.expansion, self.downsampling, self.conv = expansion, downsampling, conv  #conv = a conv2d with auto padding we created
        # the shorcat is a sequence of convolution and batch normalazation
        self.shortcut = nn.Sequential(OrderedDict(
        {
            'conv' : nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size=1,
                      stride=self.downsampling, bias=False),   # expandind channels, downsampling by stride
            'bn' : nn.BatchNorm2d(self.expanded_channels)
            
        })) if self.should_apply_shortcut else None   # is channels miss match
        
        
    @property
    def expanded_channels(self):
        return self.out_channels * self.expansion
    
    @property
    def should_apply_shortcut(self):
        return self.in_channels != self.expanded_channels

    
def conv_bn(in_channels, out_channels, conv, *args, **kwargs):
    return nn.Sequential(OrderedDict({'conv': conv(in_channels, out_channels, *args, **kwargs), 
                          'bn': nn.BatchNorm2d(out_channels) }))

 
class ResNetBasicBlock(ResNetResidualBlock):
    """set the block to a sequence of 2 conv_bn which are a sequence of conv + bn"""
    expansion = 1
    def __init__(self, in_channels, out_channels, activation=nn.ReLU, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        self.blocks = nn.Sequential(
            conv_bn(self.in_channels, self.out_channels, conv=self.conv, bias=False, stride=self.downsampling),
            activation(),  # first convbn can downsample 2d
            conv_bn(self.out_channels, self.expanded_channels, conv=self.conv, bias=False),  # second convbn - no downsample, possible channel expansion
        )

        
class ResNetLayer(nn.Module):
    """stack n layers one on each other, first layer can downsample if in!=out channels"""
    def __init__(self, in_channels, out_channels, block=ResNetBasicBlock, n=1, *args, **kwargs):
        super().__init__()
        # 'We perform downsampling directly by convolutional layers that have a stride of 2.'
        downsampling = 2 if in_channels != out_channels else 1
        # the block in the layer is a sequence of blocks
        self.blocks = nn.Sequential(
             block(in_channels , out_channels, *args, **kwargs, downsampling=downsampling), # first block with downsample
            *[block(out_channels * block.expansion, 
                    out_channels, downsampling=1, *args, **kwargs) for _ in range(n - 1)]  # all other blocks no downsampling, possible expantion
        )

    def forward(self, x):
        x = self.blocks(x)
        return x
 
class ResnetDecoder(nn.Module):
    """
    This class represents the tail of ResNet. It performs a global pooling and maps the output to the
    correct class by using a fully connected layer.
    """
    def __init__(self, in_features, n_classes):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d((1, 1))    # average along spacial area H * W
        self.decoder = nn.Linear(in_features, n_classes) #fully connected with n_classes out neurons
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        x = self.avg(x)
        print(x.shape)
        x = x.view(x.size(0), -1)   # flat the tensor
        print(x.shape)
        x = self.decoder(x)
        print(x.shape)
        x = self.sigmoid(x)
        return x
    
    
class ResNet(nn.Module):
    def __init__(self, in_channels, out_channels, stride, n_classes):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1,self.stride)
        self.layer = ResNetLayer(self.out_channels, self.out_channels, block=ResNetBasicBlock, n=5)
        self.decoder = ResnetDecoder(self.out_channels, n_classes)
        #self.encoder = ResNetEncoder(in_channels, *args, **kwargs)
        #self.decoder = ResnetDecoder(self.encoder.blocks[-1].blocks[-1].expanded_channels, n_classes)


    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.layer(x)
        #x = self.encoder(x)
        x = self.decoder(x)
        return x
    


In [128]:
"""run the model"""
d = 2
input_channel = 730
f = input_channel // d
n_classes = 10
stride_first_layer = 2
ResNetModel = ResNet(input_channel ,f, stride_first_layer, n_classes)
ResNetModel



ResNet(
  (conv1): Conv2d(730, 365, kernel_size=(1, 1), stride=(2, 2))
  (layer): ResNetLayer(
    (blocks): Sequential(
      (0): ResNetBasicBlock(
        (blocks): Sequential(
          (0): Sequential(
            (conv): Conv2dAuto(365, 365, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(365, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): ReLU()
          (2): Sequential(
            (conv): Conv2dAuto(365, 365, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(365, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (shortcut): None
      )
      (1): ResNetBasicBlock(
        (blocks): Sequential(
          (0): Sequential(
            (conv): Conv2dAuto(365, 365, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(365, eps=1e-05, momentum=0.1, affine=True, track_ru

In [129]:
x = torch.ones((730,300,600))
x2 = ResNetModel(x[None, ...])

RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:62] data. DefaultCPUAllocator: not enough memory: you tried to allocate %dGB. Buy new RAM!0


# try

In [None]:
np.random.seed(0)

# global parameters
storage_path = "D:/My Drive/StoragePath"
result_path = storage_path + "/ExpResults/tal_exp_results"
results_path = result_path + "/results"
data_path = result_path + "/no_note_background"
data_set_path = storage_path + "/Datasets/tal_datesets/corn_data_set"
labels_df = pd.read_csv(f"{data_set_path}/phenotyping.csv") 

params = get_params()
if params['isTuning']:
    study = optuna.create_study()
    study.optimize(Objective(data_path, tunning_params, params), n_trials=15)
    print(f'finished tunning params. best params = {study.best_params} with accuarcy: {study.best_value}')
    params = visualize_study_set_best_params(study, tunning_params['visualizing_params'], params)

images_and_labels = get_data(data_path,
                             params['size'])  # return dict with labels: data['images'] and data['labels']
""" HOG part """
class_names = class_ind_to_name(images_and_labels["labels"], class_indices)
df = image_2_vec(images_and_labels, class_names, params["orient"], params["ppc"],
                                        params["cpb"])

""" SVM part"""
# inserting train_test_split to train_svm func
X_train, X_test, y_train, y_test = split_data(df)

acc, y_pred, y_test, decision_mat = train_svm(X_train, X_test, y_train, y_test, params, class_names)

print(f"""
>>> model acc = {acc}
""")

labels = df.iloc[:, -1:]  # last column in transformed df before split is label column

two_biggest_errors_viz(decision_mat, y_test, labels, class_indices, data_path)

In [126]:
def get_params():
    """
    define hyper-parameters
    :return: dict of hyper-parameters
    """
    return {
        'd': 2,
        'n_classes': 10,
        'stride_first_layer': 2,
        'isTuning': False
    }

# f = input_channel // d


In [None]:
def get_tunning_params():
    parameters = {
        's': [80, 92, 100, 110, 120],  # in tuning part 1 - [32, 64, 80, 92, 100, 110, 120, 128, 160, 180, 200],
        'orient': list(range(13, 20)),  # in tuning part 1 - list(range(1, 20)),
        'cpb': [3, 3],  # in tuning part 1 -list(range(1, 5)),
        'visualizing_params': ["s", "orient"]
    }
    return parameters

In [298]:
def get_data(path, size):
    """
    :param path: data path
    :param size: width and length of images (tuple)
    :return:  dictionary of np.arrays of images and labels
    """
    img_list = []
    label_list = []
    dir_list = os.listdir(path)

    print("Getting data from source...")

    for i, img in enumerate(dir_list):
        if i > 5:
            break
        if ".img" in img:
            continue
        image_data = envi.open(f'{path}/{img}')
        resized = resize(image_data, size)
        """get mean and std from top 2 rows for background"""
        mean = np.median(image_data[:, 0:1, :], axis=(0, 1))  # median for every band
        black_std = np.std(image_data[:, 0:1, :], axis=(0, 1))
        final_image = pad_img(resized, size, mean, black_std)
        img_list.append(final_image)
        labels = get_labels(img)
        label_list.append(labels)

    print("Getting data done.")
    return {'images': np.array(img_list), 'labels': np.array(label_list)}


def resize(image_data, size):
    """resizing by the larger aspect ratio"""
    (w_target, h_target) = size
    (h_origin, w_origin) = image_data.shape[:2]
    h_ratio = h_target / h_origin
    w_ratio = w_target / w_origin
    if h_ratio > w_ratio:
        ratio = w_ratio
    else: 
        ratio = h_ratio
    dim = (int(w_origin * ratio), int(h_origin * ratio)) 
    resized_img = np.zeros((dim[1], dim[0], image_data.shape[2]))
    print(dim)
    for b in range(image_data.shape[2]):
        resized_band = cv2.resize(image_data[:, :, b], dim, interpolation = cv2.INTER_AREA)
        resized_img[:, :, b] = resized_band
    return resized_img


def pad_img(resized, size, mean, black_std):
    """
    pad the image with normal dist of background
    taking mean and std from 2 top rows    
    """
    (w_target, h_target) = size
    (h_origin, w_origin) = resized.shape[:2]
    final_image = np.zeros((h_target, w_target, resized.shape[2]))
    for band in range(resized.shape[2]):
        padded_img = np.random.normal(mean[band], black_std[band], (h_target, w_target))
        # compute center offset
        xx = (w_target - w_origin) // 2
        yy = (h_target - h_origin) // 2
        # enter image
        padded_img[yy:yy + h_origin, xx:xx + w_origin] = resized[:, :, band]
        final_image[:, :, band] = padded_img
    return final_image
        

def get_labels(img):
    date = img.split('_')[2]
    date = date.replace('-', '')
    date = int(date[2:])
    #img_name = img.split('_')[1]
    #img_name = img_name.replace('plot', '')
    if ("plot" in img_name) == False:
        print("found image not of plot in no_note_background")
        return False
    plot_num = int(img_name.split('plot')[1].split("_")[0])
    if plot_num > 0:
        # print(date)
        print(plot_num)
        labels = labels_df[(labels_df['SampleDate'] == date) & (labels_df['plot'] == plot_num)].iloc[0]    
    return labels

In [237]:
def resize(image_data, size):
    """resizing bt the larger aspect ratio"""
    (w_target, h_target) = size
    (h_origin, w_origin) = image_data.shape[:2]
    h_ratio = h_target / h_origin
    w_ratio = w_target / w_origin
    if h_ratio > w_ratio:
        ratio = w_ratio
    else: 
        ratio = h_ratio
    dim = (int(w_origin * ratio), int(h_origin * ratio)) 
    resized_img = np.zeros((dim[1], dim[0], image_data.shape[2]))
    print(dim)
    for b in range(image_data.shape[2]):
        resized_band = cv2.resize(image_data[:, :, b], dim, interpolation = cv2.INTER_AREA)
        resized_img[:, :, b] = resized_band
    return resized_img

In [199]:
image_data = envi.open("D:/My Drive/StoragePath/ExpResults/tal_exp_results/no_note_background/Corn_plot9_2019-12-24_08-28-25_no_note.hdr")

In [290]:
def get_labels(img):
    date = img.split('_')[2]
    date = date.replace('-', '')
    date = int(date[2:])
    #img_name = img.split('_')[1]
    #img_name = img_name.replace('plot', '')
    if ("plot" in img_name) == False:
        print("found image not of plot in no_note_background")
        return False
    plot_num = int(img_name.split('plot')[1].split("_")[0])
    if plot_num > 0:
        # print(date)
        print(plot_num)
        labels = labels_df[(labels_df['SampleDate'] == date) & (labels_df['plot'] == plot_num)].iloc[0]    
    return labels

In [299]:
images_and_labels = get_data(data_path,
                             (512,256))  # return dict with labels: data['images'] and data['labels']

Getting data from source...
(511, 172)


MemoryError: 

In [295]:
# global parameters
storage_path = "D:/My Drive/StoragePath"
result_path = storage_path + "/ExpResults/tal_exp_results"
results_path = result_path + "/results"
data_path = result_path + "/no_note_background"
data_set_path = storage_path + "/Datasets/tal_datesets/corn_data_set"
labels_df = pd.read_csv(f"{data_set_path}/phenotyping.csv") 

(11,)

Unnamed: 0,plot,SampleDate,necrosis,Burning,Bleaching,Chlorosis,Epinasty curling,Inhibited growth,Wilting,Disturbed apical bud-> abnormal,Y_cropped
0,1,191222,1,1,1,1,1,1,1,1,
1,2,191222,1,1,1,1,1,1,1,1,
2,3,191222,1,1,1,1,1,1,1,1,
3,4,191222,1,1,1,1,1,1,1,1,
4,5,191222,1,1,1,1,1,1,1,1,
5,6,191222,1,1,1,1,1,1,1,1,
6,7,191222,1,1,1,1,1,1,1,1,
7,8,191222,1,1,1,1,1,1,1,1,
8,9,191222,1,1,1,1,1,1,1,1,
9,10,191222,1,1,1,1,1,1,1,1,


In [284]:
labels = labels_df[(labels_df['SampleDate'] == 200101) & (labels_df['plot'] == 98)].iloc[0]  

In [285]:
labels

plot                                    98.0
SampleDate                          200101.0
necrosis                                 4.0
Burning                                  1.0
Bleaching                                1.0
Chlorosis                                3.0
Epinasty curling                         1.0
Inhibited growth                         4.0
Wilting                                  5.0
Disturbed apical bud-> abnormal          1.0
Y_cropped                                NaN
Name: 692, dtype: float64

In [264]:
def pad_img(resized, size, mean, black_std):
    """
    pad the image with normal dist of background
    taking mean and std from 2 top rows    
    """
    (w_target, h_target) = size
    (h_origin, w_origin) = resized.shape[:2]
    final_image = np.zeros((h_target, w_target, resized.shape[2]))
    for band in range(resized.shape[2]):
        padded_img = np.random.normal(mean[band], black_std[band], (h_target, w_target))
        # compute center offset
        xx = (w_target - w_origin) // 2
        yy = (h_target - h_origin) // 2
        # enter image
        padded_img[yy:yy + h_origin, xx:xx + w_origin] = resized[:, :, band]
        final_image[:, :, band] = padded_img
    return final_image

In [265]:
mean = np.median(image_data[:, 0:1, :], axis=(0, 1))  # median for every band
black_std = np.std(image_data[:, 0:1, :], axis=(0, 1))
final_umage = pad_img(img, size, mean, black_std)

In [253]:
mean = np.median(image_data[:, 0:1, :], axis=(0, 1))  # median for every band
black_std = np.std(image_data[:, 0:1, :], axis=(0, 1))
resized = img
(w_target, h_target) = size
(h_origin, w_origin) = resized.shape[:2]
final_image = np.zeros((h_target, w_target, resized.shape[2]))
for band in range(resized.shape[2]):
    padded_img = np.random.normal(mean[band], black_std[band], (h_target, w_target))
    # compute center offset
    xx = (w_target - w_origin) // 2
    yy = (h_target - h_origin) // 2
    # enter image
    padded_img[yy:yy + h_origin, xx:xx + w_origin] = resized[:, :, band]
    final_image[:, :, band] = padded_img
    return final_image

In [267]:
save_rgb(result_path + "/RGB.png", final_umage, [430, 179 + 20, 108])

In [226]:
cv2.imwrite(result_path +'/plot10_wave123.tif',resized.reshape((100,400,1)) ) #image_data[:, :, 123]

True

In [266]:
final_umage.shape

(256, 512, 730)

In [225]:
resized.reshape((100,400,1)).shape

(100, 400, 1)

In [103]:
class ResnetDecoder(nn.Module):
    """
    This class represents the tail of ResNet. It performs a global pooling and maps the output to the
    correct class by using a fully connected layer.
    """
    def __init__(self, in_features, n_classes):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.decoder = nn.Linear(in_features, n_classes)

    def forward(self, x):
        x = self.avg(x)
        print(x.shape)
        x = x.view(x.size(0), -1)
        print(x.shape)
        x = self.decoder(x)
        print(x.shape)
        return x

In [25]:

import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from dataclasses import dataclass
from collections import OrderedDict    
        

In [2]:
class Conv2dAuto(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.padding =  (self.kernel_size[0] // 2, self.kernel_size[1] // 2) # dynamic add padding based on the kernel_size
        
conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False)      
conv1x1 = partial(Conv2dAuto, kernel_size=1, bias=False)        

In [91]:

del conv

NameError: name 'conv' is not defined

In [4]:

class ResidualBlock(nn.Module):
    """set the basic format of a block"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels, self.out_channels =  in_channels, out_channels
        # blocks in block
        self.blocks = nn.Identity()
        # in channels != out channels there is a sortcut to the residual that will fit outchannels
        self.shortcut = nn.Identity()   
    
    def forward(self, x):
        residual = x
        if self.should_apply_shortcut: residual = self.shortcut(x)
        x = self.blocks(x)
        x += residual
        return x
    
    @property
    def should_apply_shortcut(self):
        return self.in_channels != self.out_channels

In [5]:
ResidualBlock(32, 64)

ResidualBlock(
  (blocks): Identity()
  (shortcut): Identity()
)

In [7]:

class ResNetResidualBlock(ResidualBlock):
    """set the shortcut"""
    def __init__(self, in_channels, out_channels, expansion=1, downsampling=1, conv=conv3x3, *args, **kwargs):
        super().__init__(in_channels, out_channels)
        self.expansion, self.downsampling, self.conv = expansion, downsampling, conv  #conv = a conv2d with auto padding we created
        # the shorcat is a sequence of convolution and batch normalazation
        self.shortcut = nn.Sequential(OrderedDict(
        {
            'conv' : nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size=1,
                      stride=self.downsampling, bias=False),   # expandind channels, downsampling by stride
            'bn' : nn.BatchNorm2d(self.expanded_channels)
            
        })) if self.should_apply_shortcut else None   # is channels miss match
        
        
    @property
    def expanded_channels(self):
        return self.out_channels * self.expansion
    
    @property
    def should_apply_shortcut(self):
        return self.in_channels != self.expanded_channels

In [8]:
ResNetResidualBlock(32, 64)

ResNetResidualBlock(
  (blocks): Identity()
  (shortcut): Sequential(
    (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [9]:

def conv_bn(in_channels, out_channels, conv, *args, **kwargs):
    """function that returns a sequence of conv & bn"""
    return nn.Sequential(OrderedDict({'conv': conv(in_channels, out_channels, *args, **kwargs), 
                          'bn': nn.BatchNorm2d(out_channels) }))

In [10]:
conv_bn(3, 3, nn.Conv2d, kernel_size=3)

Sequential(
  (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [11]:

class ResNetBasicBlock(ResNetResidualBlock):
    """set the block to a sequence of 2 conv_bn which are a sequence of conv + bn"""
    expansion = 1
    def __init__(self, in_channels, out_channels, activation=nn.ReLU, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        self.blocks = nn.Sequential(
            conv_bn(self.in_channels, self.out_channels, conv=self.conv, bias=False, stride=self.downsampling),
            activation(),  # first convbn can downsample 2d
            conv_bn(self.out_channels, self.expanded_channels, conv=self.conv, bias=False),  # second convbn - no downsample, possible channel expansion
        )

In [12]:
block = ResNetBasicBlock(32, 64)   # create instance of the model
print(block)

ResNetBasicBlock(
  (blocks): Sequential(
    (0): Sequential(
      (conv): Conv2dAuto(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ReLU()
    (2): Sequential(
      (conv): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (shortcut): Sequential(
    (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)


ResNetBasicBlock(
  (blocks): Sequential(
    (0): Sequential(
      (conv): Conv2dAuto(730, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ReLU()
    (2): Sequential(
      (conv): Conv2dAuto(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (shortcut): Sequential(
    (conv): Conv2d(730, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)


In [78]:

class ResNetLayer(nn.Module):
    """"""
    def __init__(self, in_channels, out_channels, block=ResNetBasicBlock, n=1, *args, **kwargs):
        super().__init__()
        # 'We perform downsampling directly by convolutional layers that have a stride of 2.'
        downsampling = 2 if in_channels != out_channels else 1
        # the block in the layer is a sequence of blocks
        self.blocks = nn.Sequential(
             block(in_channels , out_channels, *args, **kwargs, downsampling=downsampling), # first block with downsample
            *[block(out_channels * block.expansion, 
                    out_channels, downsampling=1, *args, **kwargs) for _ in range(n - 1)]  # all other blocks no downsampling, possible expantion
        )

    def forward(self, x):
        x = self.blocks(x)
        return x

In [94]:
layer = ResNetLayer(64, 64, block=ResNetBasicBlock, n=5)
del layer

In [93]:
del block

# Test

In [82]:
class ResNet(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1,self.stride)
        self.layer = ResNetLayer(self.out_channels, self.out_channels, block=ResNetBasicBlock, n=5)
        #self.encoder = ResNetEncoder(in_channels, *args, **kwargs)
        #self.decoder = ResnetDecoder(self.encoder.blocks[-1].blocks[-1].expanded_channels, n_classes)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = layer(x)
        #x = self.encoder(x)
        #x = self.decoder(x)
        return x

In [84]:
d = 2
input_channel = 730
f = input_channel // d
conv1 = conv1x1(730, d)
stride_first_layer = 2
ResNetModel = ResNet(input_channel,f,stride_first_layer)
ResNetModel

ResNet(
  (conv1): Conv2d(730, 365, kernel_size=(1, 1), stride=(2, 2))
  (layer): ResNetLayer(
    (blocks): Sequential(
      (0): ResNetBasicBlock(
        (blocks): Sequential(
          (0): Sequential(
            (conv): Conv2dAuto(365, 365, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(365, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): ReLU()
          (2): Sequential(
            (conv): Conv2dAuto(365, 365, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(365, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (shortcut): None
      )
      (1): ResNetBasicBlock(
        (blocks): Sequential(
          (0): Sequential(
            (conv): Conv2dAuto(365, 365, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(365, eps=1e-05, momentum=0.1, affine=True, track_ru

In [85]:
import numpy as np
x = torch.ones((730,300,600))

In [86]:
x2 = ResNetModel(x[None, ...])

RuntimeError: Given groups=1, weight of size 64 64 3 3, expected input[1, 365, 150, 300] to have 64 channels, but got 365 channels instead

In [60]:
x2.shape

torch.Size([300, 300, 300, 1])

In [74]:
conv1x1 = partial(Conv2dAuto, kernel_size=1, bias=False)  

In [75]:
d = 2
input_channel = 730
f = input_channel / d
conv1 = conv1x1(730, d)

import numpy as np
x = torch.ones((730,300,600))
x2 = ResNetModel(x[None, ...])