In [None]:
import torch
import numpy as np
from torch import nn

# **Sections A + B**

In [None]:
class SplitLinear(nn.Module):
  """Apply one shared Linear+ReLU to each half of the feature axis, then concatenate.

  The input is split along dim=1 into two equal halves. The *same* ``nn.Linear``
  (``self.m``) followed by ``nn.ReLU`` (``self.r``) is applied to each half, and
  the two results are concatenated along dim=1.

  Subclassing ``nn.Module`` registers ``self.m`` as a submodule. Its weight and
  bias therefore appear in ``.parameters()``, which an optimizer needs. The layer
  can be called as ``layer(x)`` as well as ``layer.forward(x)``.

  Args:
    num_of_features: width of the input's dim 1. It must be even for forward().
    out_features: output width of the shared linear map for each half. The
      concatenated output therefore has ``2 * out_features`` columns.
  """

  def __init__(self, num_of_features, out_features=2):
    # nn.Module requires this call before any submodule is assigned.
    super().__init__()
    self.num_of_features = num_of_features
    # One Linear is shared by both halves; each half has num_of_features // 2 inputs.
    self.m = nn.Linear(int(num_of_features) // 2, out_features)
    self.r = nn.ReLU()

  def forward(self, x):
    """Return ``ReLU(m(left_half))`` concatenated with ``ReLU(m(right_half))`` along dim=1.

    Args:
      x: tensor whose dim 1 holds the features. This presumably means a 2-D
        (N, num_of_features) batch; TODO confirm no other rank is intended.

    Returns:
      For 2-D input, a tensor of shape (N, 2 * out_features).

    Raises:
      ValueError: if ``x.shape[1]`` is odd, since it cannot be split into two
        equal halves.
    """
    width = x.shape[1]
    if width % 2:
      raise ValueError(
          f"SplitLinear expects an even number of features, got {width}")
    left, right = torch.split(x, width // 2, dim=1)
    # self.m is applied to both halves, so their weights are shared.
    y_left = self.r(self.m(left))
    y_right = self.r(self.m(right))
    return torch.cat((y_left, y_right), dim=1)

In [None]:
import sklearn.datasets as skds

# Toy data: 10 samples with 4 features each, drawn around 2 centers.
# random_state pins the draw so the printed matrix is reproducible.
x, y = skds.make_blobs(
    n_samples=10,
    n_features=4,
    centers=2,
    random_state=1,
)
print(f'this is the original x:\n{x}')

this is the original x:
[[-0.79415228  2.10495117 -8.25290074 -4.71455545]
 [-7.25671774 -9.04085707 -7.02195407 -1.39633086]
 [-1.34052081  4.15711949 -8.53560457 -6.01348926]
 [-1.83198811  3.52863145 -9.95549876 -3.37053333]
 [-7.75205488 -8.99843375 -6.9460419  -3.10145006]
 [-1.98197711  4.02243551 -8.86394306 -5.05323981]
 [-6.16402623 -8.83695596 -6.397686   -4.02455489]
 [-2.76017908  5.55121358 -9.09612178 -3.45085421]
 [-7.33277026 -7.62287264 -6.96645652 -3.48553899]
 [-8.18219253 -7.91881241 -4.6149936  -2.3467413 ]]


In [None]:
# Row-major flatten: all of sample 0's features, then sample 1's, and so on.
# .copy() makes fl an independent array (never a view of x), as flatten() does.
fl = x.reshape(-1).copy()
print(f'this is the flat x:\n{fl}')

this is the flat x:
[-0.79415228  2.10495117 -8.25290074 -4.71455545 -7.25671774 -9.04085707
 -7.02195407 -1.39633086 -1.34052081  4.15711949 -8.53560457 -6.01348926
 -1.83198811  3.52863145 -9.95549876 -3.37053333 -7.75205488 -8.99843375
 -6.9460419  -3.10145006 -1.98197711  4.02243551 -8.86394306 -5.05323981
 -6.16402623 -8.83695596 -6.397686   -4.02455489 -2.76017908  5.55121358
 -9.09612178 -3.45085421 -7.33277026 -7.62287264 -6.96645652 -3.48553899
 -8.18219253 -7.91881241 -4.6149936  -2.3467413 ]


In [None]:
# nn.Linear initializes its weights randomly. Fixing the seed makes this cell's
# printed output reproducible on Restart & Run All.
torch.manual_seed(0)
t = SplitLinear(4)  # 4 features -> two halves of 2, each mapped to out_features=2
# .float() casts to float32, the dtype of nn.Linear's default parameters.
data = torch.from_numpy(x).float()
y = t.forward(data)  # NOTE: rebinds y, which previously held the make_blobs labels
print(y)

tensor([[0.7733, 0.7404, 5.8278, 2.4732],
        [5.7327, 0.7607, 4.7507, 2.8074],
        [0.8610, 1.5890, 6.1346, 2.2447],
        [1.2139, 1.6525, 6.6602, 3.6891],
        [6.0135, 1.0159, 4.8958, 2.2879],
        [1.2456, 1.8658, 6.2175, 2.6774],
        [5.0803, 0.2820, 4.6820, 1.7577],
        [1.5248, 2.6801, 6.1738, 3.2445],
        [5.6195, 1.1991, 4.9501, 2.1893],
        [6.1419, 1.5323, 3.4687, 1.3571]], grad_fn=<CatBackward0>)


# **Sections D + E + F in the Google Docs**