In [None]:
import os
import numpy as np
import pickle
import logging
import pathlib
import datetime
import time
import glob
import collections
import tqdm
import argparse


def spatial(args):
    """
    Converts the raw per-section tables into one ``.npz`` record per spot.

    For every patient/section returned by ``load_raw(args.root)`` this:

    * writes ``subtype.pkl`` (patient -> subtype) and ``gene.pkl`` (sorted
      union of all gene column names) into ``args.dest``;
    * zero-pads each section's count table to the full gene list and orders
      the columns by gene name, so count vectors are comparable across
      sections;
    * for each spot whose ``window`` x ``window`` RGB crop lies inside the
      image and whose id has a count row, saves
      ``<dest>/<subtype>/<patient>/<section>_<x>_<y>.npz`` holding ``count``,
      ``pixel``, ``patient``, ``section`` and ``index`` (the crop is only used
      for the boundary check and is not stored);
    * creates an empty ``.<section>`` marker file once a section is finished so
      that an interrupted run can resume where it left off;
    * finally writes the per-gene ``mean_expression.npy`` and
      ``median_expression.npy`` into ``args.dest`` if either is missing.

    Args:
        args: namespace with ``root`` (raw data directory) and ``dest``
            (output directory).
    """
    import skimage.io

    window = 224  # only to check if patch is off of boundary
    # Same offsets the crop has always used: rows/cols [centre + lo, centre + hi).
    lo, hi = -window // 2, window // 2

    logger = logging.getLogger(__name__)

    pathlib.Path(args.dest).mkdir(parents=True, exist_ok=True)

    raw, subtype = load_raw(args.root)

    with open(args.dest + "/subtype.pkl", "wb") as f:
        pickle.dump(subtype, f)

    t = time.time()

    # Union of the gene columns over all sections; the first column of each
    # table holds the spot id and is kept apart as `section_header`.
    t0 = time.time()
    section_header = None
    gene_names = set()
    for patient in raw:
        for section in raw[patient]:
            columns = raw[patient][section]["count"].columns.values
            section_header = columns[0]
            gene_names.update(columns[1:])
    gene_names = sorted(gene_names)
    with open(args.dest + "/gene.pkl", "wb") as f:
        pickle.dump(gene_names, f)
    # NOTE(review): assumes every section names its first column identically,
    # otherwise the header of the last section shows up as a "missing gene" below.
    gene_names = [section_header] + gene_names
    logger.info("Finding list of genes: " + str(time.time() - t0))

    for (i, patient) in enumerate(raw):
        logger.info("Processing " + str(i + 1) + " / " + str(len(raw)) + ": " + patient)
        section_dir = "{}/{}/{}".format(args.dest, subtype[patient], patient)

        for section in raw[patient]:

            pathlib.Path(section_dir).mkdir(parents=True, exist_ok=True)

            # This is just a blank file to indicate that the section has been completely processed.
            # Preprocessing occasionally crashes, and this lets the preparation restart from where it left off.
            complete_filename = "{}/.{}".format(section_dir, section)
            if pathlib.Path(complete_filename).exists():
                logger.info("Patient {} section {} has already been processed.".format(patient, section))
                continue

            logger.info("Processing " + patient + " " + section + "...")
            table = raw[patient][section]["count"]

            # In the original data, genes with no expression in a section are dropped from the table.
            # This adds the columns back in so that comparisons across the sections can be done.
            t0 = time.time()
            missing = list(set(gene_names) - set(table.keys()))
            c = table.values[:, 1:].astype(float)
            pad = np.zeros((c.shape[0], len(missing)))
            c = np.concatenate((c, pad), axis=1)
            names = np.concatenate((table.keys().values[1:], np.array(missing)))
            c = c[:, np.argsort(names)]
            logger.info("Adding zeros and ordering columns: " + str(time.time() - t0))

            # Spot id -> count vector.  Rows are paired positionally, which avoids
            # both the per-row Series construction of iterrows() and any
            # dependence on the table's index labels.
            t0 = time.time()
            count = dict(zip(table.iloc[:, 0].values, c))
            logger.info("Extracting counts: " + str(time.time() - t0))

            t0 = time.time()
            image = skimage.io.imread(raw[patient][section]["image"])
            logger.info("Loading image: " + str(time.time() - t0))

            t0 = time.time()
            for (_, row) in raw[patient][section]["spot"].iterrows():
                # The single spot column looks like "<spot_x>x<spot_y>,<pixel_x>,<pixel_y>".
                fields = row.values[0].split(",")
                x = int(round(float(fields[1])))   # pixel coordinates of the spot centre
                y = int(round(float(fields[2])))
                spot_x, spot_y = (int(v) for v in fields[0].split("x")[:2])  # spot id
                key = str(spot_x) + "x" + str(spot_y)

                X = image[(y + lo):(y + hi), (x + lo):(x + hi), :]
                if X.shape != (window, window, 3):
                    logger.warning("Detected spot too close to edge.")
                    continue
                if key not in count:
                    logger.warning("Patch " + key + " not found in " + patient + " " + section)
                    continue

                filename = "{}/{}_{}_{}.npz".format(section_dir, section, spot_x, spot_y)
                np.savez_compressed(filename, count=count[key],
                                    pixel=np.array([x, y]),
                                    patient=np.array([patient]),
                                    section=np.array([section]),
                                    index=np.array([spot_x, spot_y]))
            logger.info("Saving patches: " + str(time.time() - t0))

            with open(complete_filename, "w"):
                pass
    logger.info("Preprocessing took " + str(time.time() - t) + " seconds")

    # Dataset-wide statistics live next to the patches they were computed from.
    mean_path = os.path.join(args.dest, "mean_expression.npy")
    median_path = os.path.join(args.dest, "median_expression.npy")
    if os.path.isfile(mean_path) and os.path.isfile(median_path):
        return

    logger.info("Computing statistics of dataset")
    gene = []
    for filename in tqdm.tqdm(glob.glob("{}/*/*/*_*_*.npz".format(args.dest))):
        # Close each archive as soon as its count vector has been read.
        with np.load(filename) as npz:
            gene.append(np.expand_dims(npz["count"], 1))

    if not gene:
        logger.warning("No patches found under {}; skipping dataset statistics.".format(args.dest))
        return

    gene = np.concatenate(gene, 1)  # shape: (genes, spots)

    logger.info("There are {} genes and {} spots in total for all the patients and sections.".format(gene.shape[0], gene.shape[1]))

    np.save(mean_path, np.mean(gene, 1))
    np.save(median_path, np.median(gene, 1))


def newer_than(file1, file2):
    """
    Returns True if file1 is newer than file2.
    A typical use case is if file2 is generated using file1.
    For example:

    if newer_than(file1, file2):
        # update file2 based on file1

    "Newer" is decided by last-modification time.  (The change/creation time
    reported by ``os.path.getctime`` is not suitable: on Unix it also moves on
    metadata-only changes, and on Windows it does not move when a file is
    rewritten.)

    Returns False if file1 does not exist, and True if file1 exists but file2
    does not.
    """
    if not os.path.isfile(file1):
        return False
    if not os.path.isfile(file2):
        return True
    return os.path.getmtime(file1) > os.path.getmtime(file2)


def load_section(root: str, patient: str, section: str, subtype: str):
    """
    Loads the tables for a single section of one patient.

    Files are looked up under ``<root>/<subtype>/<patient>/<patient>_<section>``.
    The gzipped count table (``.tsv.gz``) and spot table (``.spots.gz``) are
    parsed as tab-separated files and cached as pickles (``.pkl`` and
    ``.spots.pkl``); a cache is rebuilt whenever ``newer_than`` reports its
    source as newer, and read back otherwise.

    Returns a dict with keys ``"image"`` (path of the ``.jpg``, not the pixel
    data), ``"count"`` and ``"spot"`` (the two parsed tables).
    """
    import pandas
    import gzip

    file_root = "/".join((root, subtype, patient, patient + "_" + section))

    def read_counts(path):
        # The count table is decompressed explicitly before parsing.
        with gzip.open(path, "rb") as f:
            return pandas.read_csv(f, sep="\t")

    def read_spots(path):
        # The spot table is handed to pandas by path.
        return pandas.read_csv(path, sep="\t")

    def cached(source, cache, reader):
        # Re-parse and refresh the pickle when the source is newer; otherwise
        # read the pickle.
        if newer_than(source, cache):
            table = reader(source)
            with open(cache, "wb") as out:
                pickle.dump(table, out)
            return table
        with open(cache, "rb") as src:
            return pickle.load(src)

    return {
        # Only the path is returned; the image is read later by the caller.
        "image": file_root + ".jpg",
        "count": cached(file_root + ".tsv.gz", file_root + ".pkl", read_counts),
        "spot": cached(file_root + ".spots.gz", file_root + ".spots.pkl", read_spots),
    }


def load_raw(root: str):
    """
    Loads the data of every section of every patient found under ``root``.

    Sections are discovered from image files matching
    ``<root>/<subtype>/<patient>/<patient>_<section>.jpg``.

    Returns a pair ``(data, subtype)`` where ``data[patient][section]`` is the
    dict produced by ``load_section`` and ``subtype[patient]`` is the subtype
    directory the patient was found in.

    Raises:
        ValueError: if one patient appears under two different subtypes.
    """

    logger = logging.getLogger(__name__)

    # Wildcard search for patients/sections
    images = glob.glob(root + "/*/*/*_*.jpg")

    # Dict mapping patient ID (str) to a list of all sections available for the patient (List[str])
    patient = collections.defaultdict(list)
    for path in images:
        # File name without the 4-character ".jpg" suffix is "<patient>_<section>".
        p, s = path.split("/")[-1][:-4].split("_")
        patient[p].append(s)

    # Dict mapping patient ID (str) to subtype (str)
    subtype = {}
    for path in images:
        parts = path.split("/")
        st = parts[-3]
        p = parts[-1][:-4].split("_")[0]
        if subtype.setdefault(p, st) != st:
            raise ValueError("Patient {} is the same marked as type {} and {}.".format(p, subtype[p], st))

    logger.info("Loading raw data...")
    t = time.time()
    data = {}
    n_sections = sum(len(sections) for sections in patient.values())
    with tqdm.tqdm(total=n_sections) as pbar:
        for p, sections in patient.items():
            data[p] = {}
            for s in sections:
                data[p][s] = load_section(root, p, s, subtype[p])
                pbar.update()
    logger.info("Loading raw data took " + str(time.time() - t) + " seconds.")

    return data, subtype



# Command-line entry point of the preprocessing step: parse the two paths and
# run spatial() on them.
parser = argparse.ArgumentParser(description='Process the paths.')

parser.add_argument('--root',  type=str, default='data/hist2tscript/',
                    help='directory holding the raw data, laid out as <root>/<subtype>/<patient>/<patient>_<section>.*')
parser.add_argument('--dest',  type=str, default='data/hist2tscript-patch/',
                    help='directory the per-spot .npz files are written to')

args = parser.parse_args()

spatial(args)



In [2]:
import numpy as np

In [5]:
# Scratch check: NumPy refuses a negative dimension; this call raises the ValueError recorded below.
np.zeros((-1, 1))

ValueError: negative dimensions are not allowed

In [19]:
# Scratch check: a 3 x 4 array of float zeros.
np.zeros((3,4))

array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])

In [27]:
# Toy matrix for the row-filter experiments in the surrounding cells:
# row 0 has two non-zero entries, row 1 has one, row 2 is all zeros.
t = np.array([[1., 0., 2., 0.],
              [0., 0., 0., 9.],
              [0., 0., 0., 0.]])

In [66]:
# Sum of each row of t (reduction over axis 1).
np.sum(t, 1)

array([3., 9., 0.])

In [68]:
# Boolean mask of the rows whose sum is non-zero.
np.sum(t, 1)!=0 

array([ True,  True, False])

In [78]:
# Row mask: keep a row only if its sum is non-zero AND at least 10% of its
# columns hold a non-zero entry.
b = (np.sum(t, 1) != 0) & (np.sum(t != 0, axis=1) >= 0.1 * t.shape[1])

In [79]:
# Rows of t selected by the boolean mask b.
t[b]

array([[1., 0., 2., 0.],
       [0., 0., 0., 9.]])

In [None]:
# Load a pickled object from gene.pkl — presumably the sorted gene-name list
# written by the preprocessing step; TODO confirm.
# NOTE(review): the path is absolute and machine-specific, and this cell relies
# on `pickle` having been imported earlier in the same kernel.
with open("/home/chenxingjian/PycharmProjects/mnt/mnt_project/FromZerotoOne/data/hist2tscript-patch/gene.pkl","rb") as f:   
    gene = pickle.load(f)

In [57]:
# Number of non-zero entries in each row of t.
(np.sum(np.array(t!=0),1))

array([2, 1, 0])

In [60]:
# Mask of the rows whose non-zero count is at least 10% of the number of columns.
((np.sum(np.array(t!=0),1))>=(0.1*t.shape[1]))

array([ True,  True, False])

In [74]:
# The same mask as integers (True -> 1, False -> 0).
((np.sum(np.array(t!=0),1))>=(0.1*t.shape[1]))*1

array([1, 1, 0])

In [61]:
# Rows of t that pass the 10% non-zero threshold.
t[((np.sum(np.array(t!=0),1))>=(0.1*t.shape[1]))]

array([[1., 0., 2., 0.],
       [0., 0., 0., 9.]])

In [53]:
# Shape of the per-row "has at least one non-zero entry" mask: one value per row of t.
((np.sum(np.array(t!=0),1))!=0).shape

(3,)

In [47]:
# Indices of the rows of t that contain at least one non-zero entry.
list(np.nonzero(np.sum(np.array(t!=0),1))[0])

[0, 1]

In [None]:
import torch
import torchvision
import numpy as np
import logging
import pathlib
import traceback
import random
import time
import os
import glob
import socket
import argparse
import collections
import utils
from efficientnet_pytorch import EfficientNet




def run_spatial(args=None):
    """
    Trains the final layer of an ImageNet-pretrained EfficientNet-b7 to regress
    per-spot gene counts from image patches.

    Steps: seed the RNGs; split the (patient, section) pairs into train and
    test according to ``args.testpatients``; hold out 10% of the training spots
    for validation; estimate the per-channel mean/std from a few training
    batches; then, for every epoch, run a train pass, a validation pass and
    (when test sections were selected) a test pass.  After each test pass the
    measured counts, predictions and spot metadata are written to
    ``args.pred_root + str(epoch + 1)`` as a compressed ``.npz``; nothing is
    written when ``args.pred_root`` is None.

    Only the parameters of the new final linear layer are handed to the
    optimizer.

    Args:
        args: namespace providing ``seed``, ``gpu``, ``testpatients``,
            ``window``, ``gene_filter``, ``batch``, ``workers``, ``average``,
            ``optim``, ``lr``, ``momentum``, ``weight_decay``, ``epochs`` and
            ``pred_root``.
    """

    ### Seed ###
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ### Select device for computation ###
    device = ("cuda" if args.gpu else "cpu")

    ### Split patients into folds ###
    # --testpatients defaults to None; treat that as "nothing held out" instead
    # of failing on `p in None`.
    held_out = args.testpatients if args.testpatients is not None else []
    sections_by_patient = get_spatial_patients()
    train_patients = []
    test_patients = []
    for p in sections_by_patient:
        for s in sections_by_patient[p]:
            if p in held_out or (p, s) in held_out:
                test_patients.append((p, s))
            else:
                train_patients.append((p, s))

    ### Dataset setup ###
    window = args.window

    def make_train_pool(transform):
        # All spots of the non-held-out sections, with the given image transform.
        return utils.Spatial(train_patients, window=window, gene_filter=args.gene_filter,
                             transform=transform)

    def make_loader(dataset):
        return torch.utils.data.DataLoader(dataset, batch_size=args.batch,
                                           num_workers=args.workers, shuffle=True, pin_memory=args.gpu)

    # 90/10 train/validation split over spot indices.  The subsets produced here
    # only serve to fix the indices and to estimate the normalisation statistics.
    pool = make_train_pool(torchvision.transforms.ToTensor())
    train_size = int(0.9 * len(pool))
    val_size = len(pool) - train_size
    stats_subset, val_subset = torch.utils.data.random_split(pool, [train_size, val_size])
    train_indices = list(stats_subset.indices)
    val_indices = list(val_subset.indices)

    # Estimate per-channel mean and standard deviation from a few batches
    t = time.time()
    n_samples = 10
    mean = 0.
    std = 0.
    n = 0
    for (i, (X, *_)) in enumerate(make_loader(stats_subset)):
        X = X.transpose(0, 1).contiguous().view(3, -1)  # (channel, all pixels of the batch)
        n += X.shape[1]
        mean += torch.sum(X, dim=1)
        std += torch.sum(X ** 2, dim=1)
        if i > n_samples:
            break
    mean /= n
    std = torch.sqrt(std / n - mean ** 2)
    print("Estimating mean (" + str(mean) + ") and std (" + str(std) + " took " + str(time.time() - t) + 's')

    # Transform and data augmentation (TODO: spatial and multiscale ensemble)
    normalize = torchvision.transforms.Normalize(mean=mean, std=std)
    train_transform = torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomVerticalFlip(),
        torchvision.transforms.RandomApply([torchvision.transforms.RandomRotation((90, 90))]),
        torchvision.transforms.ToTensor(),
        normalize])
    eval_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize])

    # A Subset forwards __getitem__ to its wrapped dataset, so assigning
    # `.transform` on the subsets returned by random_split has no effect.
    # Build one dataset per transform and re-apply the split indices instead.
    # Assumes utils.Spatial enumerates spots in the same order for identical
    # arguments — TODO confirm.
    train_dataset = torch.utils.data.Subset(make_train_pool(train_transform), train_indices)
    val_dataset = torch.utils.data.Subset(make_train_pool(eval_transform), val_indices)
    train_loader = make_loader(train_dataset)
    val_loader = make_loader(val_dataset)

    phases = [("train", train_loader), ("val", val_loader)]

    # Test data (skipped entirely when no section was held out)
    test_dataset = None
    if test_patients:
        if args.average:
            # Every patch is expanded into its eight symmetries; predictions are averaged below.
            test_transform = torchvision.transforms.Compose([
                utils.transforms.EightSymmetry(),
                torchvision.transforms.Lambda(lambda symmetries: torch.stack(
                    [normalize(torchvision.transforms.ToTensor()(s)) for s in symmetries]))])
        else:
            test_transform = eval_transform
        test_dataset = utils.Spatial(test_patients, test_transform, window=args.window, gene_filter=args.gene_filter)
        phases.append(("test", make_loader(test_dataset)))

    # Find number of required outputs
    outputs = train_dataset[0][1].shape[0]

    ### Model setup ###
    start_epoch = 0
    model = EfficientNet.from_pretrained('efficientnet-b7')
    feature = model._fc.in_features
    model._fc = torch.nn.Linear(in_features=feature, out_features=outputs)
    model.to(device)

    ### Optimizer setup ###
    optim = torch.optim.__dict__[args.optim](model._fc.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    ### Training Loop ###
    for epoch in range(start_epoch, args.epochs):
        print("Epoch #" + str(epoch + 1))

        # for each epoch, loop train, val and test
        for (dataset, loader) in phases:

            t = time.time()

            is_train = dataset == "train"
            torch.set_grad_enabled(is_train)
            if is_train:
                model.train()
            else:
                model.eval()

            # Per-spot outputs are only needed for the file written after the test pass.
            keep_outputs = dataset == "test" and args.pred_root is not None
            average_symmetries = dataset == "test" and args.average

            total = 0
            n = 0
            predictions = []
            counts = []
            coord = []
            patient_ids = []
            section_ids = []
            pixel = []

            print(dataset + ":")
            for (i, (X, gene, xy, _ind, pat, sec, pix)) in enumerate(loader):

                if keep_outputs:
                    counts.append(gene.detach().numpy())
                    coord.append(xy.detach().numpy())
                    patient_ids += pat
                    section_ids += sec
                    pixel.append(pix.detach().numpy())

                X = X.to(device)
                gene = gene.to(device)

                if average_symmetries:
                    # Fold the symmetry axis into the batch axis for the forward pass.
                    batch, n_sym, channels, h, w = X.shape
                    X = X.view(-1, channels, h, w)

                pred = model(X)

                if average_symmetries:
                    pred = pred.view(batch, n_sym, -1).mean(1)

                if keep_outputs:
                    predictions.append(pred.cpu().detach().numpy())

                # Squared error summed over the batch, averaged over the genes.
                loss = torch.sum((pred - gene) ** 2) / outputs

                total += loss.item()
                n += gene.shape[0]

                message = ""
                message += "Batch: {:8d} / {:d} ({:4.0f} / {:4.0f}):".format(i + 1, len(loader), time.time() - t, (time.time() - t) * len(loader) / (i + 1))
                message += "    Batch-based Loss={:.9f}".format(total / n)  # running per-spot loss
                print(message)

                if is_train:
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

            print("    Epoch-based Loss:       " + str(total / len(loader.dataset)))

            # Write the test predictions of this epoch.  Saving once, after the
            # test pass, yields the same file as before without first writing
            # (and then overwriting) the train and validation outputs.
            if keep_outputs:
                pathlib.Path(os.path.dirname(args.pred_root)).mkdir(parents=True, exist_ok=True)
                np.savez_compressed(args.pred_root + str(epoch + 1),
                                    task="gene",
                                    counts=np.concatenate(counts),
                                    predictions=np.concatenate(predictions),
                                    coord=np.concatenate(coord),
                                    patient=patient_ids,
                                    section=section_ids,
                                    pixel=np.concatenate(pixel),
                                    ensg_names=test_dataset.ensg_names,
                                    gene_names=test_dataset.gene_names,
                )

    # Evaluation passes switch autograd off globally; switch it back on so that
    # code running afterwards in the same process is not affected.
    torch.set_grad_enabled(True)




def get_spatial_patients():
    """
    Maps every patient to the sections available for that patient.

    Sections are discovered from the ``.jpg`` files under
    ``data/hist2tscript/<dir>/<dir>/``; each file name (up to its first dot)
    is read as ``<patient>_<section>``.

    Returns a ``collections.defaultdict(list)`` whose keys are patient names
    (str) and whose values are lists of section names (str).
    """
    sections_by_patient = collections.defaultdict(list)
    for image_path in glob.glob("data/hist2tscript/*/*/*.jpg"):
        stem = image_path.split("/")[-1].split(".")[0]
        p, s = stem.split("_")
        sections_by_patient[p].append(s)
    return sections_by_patient


def patient_or_section(name):
    """
    argparse converter for ``--testpatients`` values.

    A value containing an underscore is split on ``"_"`` and returned as a
    tuple (e.g. ``"P_S"`` -> ``("P", "S")``); any other value is returned
    unchanged as a string.
    """
    return tuple(name.split("_")) if "_" in name else name

# Command-line entry point of the training step: declare the options, parse
# them and run run_spatial().
parser = argparse.ArgumentParser(description='Process the paths.')

# Reproducibility / hardware
parser.add_argument('--seed', '-s', type=int, default=0, help='RNG seed')
parser.add_argument("--gpu", action="store_true", help="use GPU")

# Data
parser.add_argument("--testpatients", nargs="*", type=patient_or_section, default=None,
                                   help="all the rest patients will be used as the training data")
parser.add_argument("--window", type=int, default=224, help="window size")
# Command-line values arrive as strings; convert all-digit values to int so
# that "250" can match the integer choice 250 (argparse checks `choices` after
# applying `type`).
parser.add_argument("--gene_filter", type=lambda value: int(value) if value.isdigit() else value,
                       choices=["none", "high", 250], default=250,
                       help="special gene filters")
parser.add_argument("--batch", type=int, default=256, help="training batch size")    
parser.add_argument("--workers", type=int, default=4, help="number of workers for dataloader")

# Model
parser.add_argument("--model", "-m", default="vgg11",
                        # choices=sorted(name for name in torchvision.models.__dict__ if name.islower() and not name.startswith("__") and callable(torchvision.models.__dict__[name])),  TODO: autocomplete speed issue
                        help="model architecture")
parser.add_argument("--pretrained", action="store_true",
                    help="use ImageNet pretrained weights")
parser.add_argument("--finetune", type=int, nargs="?", const=1, default=None,
                             help="fine tune last n layers")
parser.add_argument("--randomize", action="store_true",
                                   help="randomize weights in layers to be fined tuned")

# Optimisation
parser.add_argument("--optim", default="SGD",
                        # choices=sorted(name for name in torchvision.models.__dict__ if name.islower() and not name.startswith("__") and callable(torchvision.models.__dict__[name])),  TODO: autocomplete speed issue and change to optim instead of model
                        help="optimizer")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
parser.add_argument("--momentum", type=float, default=0.9, help="momentum for SGD")
parser.add_argument("--weight_decay", type=float, default=0, help="weight decay for SGD")
parser.add_argument("--epochs", type=int, default=100, help="number of epochs")

# Evaluation / output
parser.add_argument("--average", action="store_true", help="average between rotations and reflections")
parser.add_argument("--pred_root", type=str, default=None, help="root for prediction outputs")

args = parser.parse_args()

run_spatial(args)



In [1]:
import numpy as np

In [26]:
# Mean squared error per entry for consecutive chunks of `batch` rows of the
# saved counts/predictions; 250 is the hard-coded number of columns.
# Chunks that start past the end of the arrays are empty and print 0.0, and a
# partial last chunk is still divided by the full chunk size.
batch = 32  # previously leaked in from kernel state; the 32 recorded outputs below match a chunk size of 32
a=np.load('BC23287_1.npz')
b=a['counts']
c=a['predictions']
for i in range(batch):
    print(np.sum((b[i*batch:(i+1)*batch,:] - c[i*batch:(i+1)*batch,:]) ** 2)/(250*batch))

0.9193314819335937
1.0168716430664062
0.87110302734375
0.8447691040039063
0.99274658203125
0.9404971313476562
0.9094579467773437
0.9560597534179688
0.830826416015625
0.86204931640625
0.9589568481445313
0.8984053955078125
0.9861575317382812
0.902651611328125
0.8830653686523438
0.88997119140625
0.86878125
0.950525146484375
0.9787474975585938
0.9335302734375
0.88711865234375
0.9798858642578125
0.9214951171875
0.9374384765625
1.0119865112304687
0.9349647216796875
0.902815185546875
0.94776904296875
0.87917236328125
0.5570457763671876
0.0
0.0


In [27]:
# Same chunked error as above, for another prediction file: 32 chunks of 32
# rows, each divided by 250 * 32.
a = np.load('/home/chenxingjian/PycharmProjects/mnt/mnt_project/FromZerotoOne/output/1.npz')
b, c = a['counts'], a['predictions']
for start in range(0, 32 * 32, 32):
    rows = slice(start, start + 32)
    print(np.sum((b[rows, :] - c[rows, :]) ** 2) / (250 * 32))

1.06476171875
1.0140812377929687
1.1718956298828125
1.0427177734375
0.9164627075195313
1.0842373046875
1.062562744140625
1.1283795166015624
0.9940906982421875
1.124470703125
1.07689453125
0.9617781982421875
1.077207763671875
1.06529736328125
1.157640625
1.09458984375
1.0609080810546876
1.12093408203125
0.9850051879882813
1.064634765625
1.013691650390625
1.026901123046875
1.17226220703125
0.9928989868164062
1.0620875244140624
0.948853759765625
1.1373251953125
1.0010823974609375
0.9962522583007812
0.67506689453125
0.0
0.0


In [28]:
# Total squared error between b and c divided by the hard-coded 250 * 947 —
# presumably the number of entries of b (a later cell shows b.shape as (947, 250)); TODO confirm.
# NOTE(review): torch is imported here but not used in this cell.
import torch
np.sum((b - c) ** 2) / (250*947)

1.0574857444561774

In [11]:
# The 'counts' array of the loaded prediction file.
b=a['counts']

In [12]:
# The 'predictions' array of the loaded prediction file.
c=a['predictions']

In [15]:
# Dimensions of the counts array.
b.shape

(947, 250)

In [24]:
# For each of 32 consecutive 32-row chunks, print the row range and then the
# chunk's squared error divided by 250 * 32.
for i in range(32):
    lo, hi = i * 32, (i + 1) * 32
    print(lo, hi)
    print(np.sum((b[lo:hi, :] - c[lo:hi, :]) ** 2) / (250 * 32))

0 32
0.9193314819335937
32 64
1.0168716430664062
64 96
0.87110302734375
96 128
0.8447691040039063
128 160
0.99274658203125
160 192
0.9404971313476562
192 224
0.9094579467773437
224 256
0.9560597534179688
256 288
0.830826416015625
288 320
0.86204931640625
320 352
0.9589568481445313
352 384
0.8984053955078125
384 416
0.9861575317382812
416 448
0.902651611328125
448 480
0.8830653686523438
480 512
0.88997119140625
512 544
0.86878125
544 576
0.950525146484375
576 608
0.9787474975585938
608 640
0.9335302734375
640 672
0.88711865234375
672 704
0.9798858642578125
704 736
0.9214951171875
736 768
0.9374384765625
768 800
1.0119865112304687
800 832
0.9349647216796875
832 864
0.902815185546875
864 896
0.94776904296875
896 928
0.87917236328125
928 960
0.5570457763671876
960 992
0.0
992 1024
0.0


In [25]:
# Total squared error between b and c divided by the hard-coded 250 * 947 —
# presumably the number of entries of b; TODO confirm.
# NOTE(review): torch is imported here but not used in this cell.
import torch
np.sum((b - c) ** 2) / (250*947)

0.9243233896515312

In [16]:
# Cross-check: 875.33425 / 947 gives the same value as the overall error in the cell above.
# 875.33425 is presumably a loss total copied from the training output — TODO confirm.
875.33425/947

0.9243233896515312