In [1]:
import torch 

In [2]:
# Record which CUDA version this PyTorch build reports, so the outputs below
# can be tied to the environment that produced them.
cuda_build_version = torch.version.cuda
print(cuda_build_version)

11.8


In [2]:
class MySharedTensor(object):
    """A logical tensor stored as a list of chunks concatenated along one dimension.

    The chunks are only concatenated when ``to_full_tensor`` is called;
    addition and matrix multiplication work chunk by chunk against matching
    slices of the other operand.

    Args:
        tensors: list of tensors; the logical tensor is
            ``torch.cat(tensors, dim=dim)``. The list is stored as-is (no copy),
            so in-place operations write through to the caller's tensors.
        dim: the dimension along which the chunks are concatenated (required).
            Negative values are accepted.
    """

    def __init__(self, tensors, dim=None):
        assert dim is not None
        self.dim = dim
        assert isinstance(tensors, list)
        # all(...) is essential: asserting a bare generator is always truthy,
        # so without it non-tensor elements would slip through.
        assert all(torch.is_tensor(t) for t in tensors)
        self.tensors = tensors
        # Size of each chunk along ``dim``; used to split the other operand.
        self.dim_sizes = [t.size(self.dim) for t in self.tensors]

    # If you want to recover a single Tensor from it
    def to_full_tensor(self):
        """Concatenate the chunks into a single dense tensor."""
        return torch.cat(self.tensors, dim=self.dim)

    # Out of place addition
    def __add__(self, other):
        """Return ``self.to_full_tensor() + other`` as a dense tensor.

        The result starts as a clone of ``other`` (so it has ``other``'s dtype
        and device), and each chunk plus its matching slice of ``other`` is
        written into the corresponding slot.

        Raises:
            RuntimeError: if ``other.size(self.dim)`` does not equal the total
                size of the chunks along ``self.dim``.
        """
        assert torch.is_tensor(other)
        out = other.clone()
        # split() raises unless the sizes sum exactly to the dim size, so an
        # oversized ``other`` can no longer leave an un-summed tail in ``out``.
        out_slices = out.split(self.dim_sizes, dim=self.dim)
        other_slices = other.split(self.dim_sizes, dim=self.dim)
        for t, out_slice, other_slice in zip(self.tensors, out_slices, other_slices):
            # out_slice is a view into ``out``, so this fills ``out`` in place.
            out_slice.copy_(t).add_(other_slice)
        return out

    # Inplace add
    def __iadd__(self, other):
        """Add ``other`` into the chunks in place and return ``self``.

        The chunk tensors themselves are mutated, so any outside references to
        them (e.g. the tensors passed to the constructor) see the update.

        Raises:
            RuntimeError: if ``other.size(self.dim)`` does not equal the total
                size of the chunks along ``self.dim``. The check happens before
                any chunk is modified.
        """
        assert torch.is_tensor(other)
        other_slices = other.split(self.dim_sizes, dim=self.dim)
        for t, other_slice in zip(self.tensors, other_slices):
            t.add_(other_slice)
        return self

    # Matrix Multiplication (only 2d matrices for simplicity)
    def mm(self, other):
        """Matrix-multiply the logical 2-D tensor by the dense 2-D tensor ``other``.

        * Chunks split along rows (dim 0 / -2): each chunk is multiplied by
          ``other`` independently and the result is returned as a
          ``MySharedTensor`` split along dim 0.
        * Chunks split along columns (dim 1 / -1): ``other`` is split along
          its rows to match, the partial products are summed, and a dense
          tensor is returned.

        Raises:
            RuntimeError: for any other dimension, or (column split) when
                ``other.size(0)`` does not equal the total chunk width.
        """
        assert other.ndimension() == 2
        assert all(t.ndimension() == 2 for t in self.tensors)

        # Normalise negative dims, consistent with the other methods which
        # already accept them via size()/cat()/split().
        dim = self.dim + 2 if self.dim in (-2, -1) else self.dim

        if dim == 0:
            return MySharedTensor([t.mm(other) for t in self.tensors], dim=0)
        elif dim == 1:
            out = 0
            # split() rejects an ``other`` whose row count differs from the
            # total chunk width instead of silently ignoring extra rows.
            for t, other_slice in zip(self.tensors, other.split(self.dim_sizes, dim=0)):
                out += t.mm(other_slice)
            return out
        else:
            raise RuntimeError("Invalid dimension")


# Seed the RNG so the printed values are reproducible on Restart & Run All.
torch.manual_seed(0)

a = torch.rand(2, 4)
b = torch.rand(2, 4)
print("a and b")
print(a)
print(b)

# c wraps a and b without copying: logically a 2x8 tensor split along columns.
c = MySharedTensor([a, b], dim=1)
print("c size:")
print(c.to_full_tensor().size())

d = torch.rand(2, 8)
print("d")
print(d)

e = c + d
print("e = c + d")
print(e)

f = c.to_full_tensor() + d
print("f = c.to_full_tensor() + d")
print(f)
# The chunked out-of-place add must match the dense reference.
torch.testing.assert_close(e, f)

c += d
print("c += d")
print(c.to_full_tensor())
print("a and b are changed:")
print(a, b)
# The in-place add writes through to the original tensors a and b.
torch.testing.assert_close(torch.cat([a, b], dim=1), f)


g = torch.rand(8, 3)
print("g")
print(g)

h = c.mm(g)
print("h = c.mm(g)")
print(h)

k = c.to_full_tensor().mm(g)
print("k = c.to_full_tensor().mm(g)")
print(k)
# Summing per-chunk partial products can differ from the dense mm by
# floating-point summation order, so compare with a tolerance.
torch.testing.assert_close(h, k)

a and b
tensor([[0.4566, 0.4961, 0.3088, 0.9393],
        [0.6863, 0.8910, 0.9775, 0.2971]])
tensor([[0.9515, 0.2214, 0.6308, 0.9002],
        [0.3194, 0.7220, 0.9200, 0.7176]])
c size:
torch.Size([2, 8])
d
tensor([[0.5807, 0.5835, 0.6693, 0.9280, 0.5016, 0.8280, 0.2869, 0.8108],
        [0.6894, 0.1667, 0.0125, 0.8738, 0.9797, 0.9346, 0.6939, 0.1501]])
e = c + d
tensor([[1.0373, 1.0796, 0.9781, 1.8673, 1.4531, 1.0494, 0.9176, 1.7110],
        [1.3757, 1.0577, 0.9900, 1.1709, 1.2991, 1.6566, 1.6138, 0.8677]])
f = c.to_full_tensor() + d
tensor([[1.0373, 1.0796, 0.9781, 1.8673, 1.4531, 1.0494, 0.9176, 1.7110],
        [1.3757, 1.0577, 0.9900, 1.1709, 1.2991, 1.6566, 1.6138, 0.8677]])
c += d
tensor([[1.0373, 1.0796, 0.9781, 1.8673, 1.4531, 1.0494, 0.9176, 1.7110],
        [1.3757, 1.0577, 0.9900, 1.1709, 1.2991, 1.6566, 1.6138, 0.8677]])
a and b are changed:
tensor([[1.0373, 1.0796, 0.9781, 1.8673],
        [1.3757, 1.0577, 0.9900, 1.1709]]) tensor([[1.4531, 1.0494, 0.9176, 1.7110],
     