In [1]:
import $file.^.Magic

[32mimport [39m[36m$file.$[39m

In [2]:
val zipName = "sms+spam+collection.zip"
val datasetUrl = s"https://archive.ics.uci.edu/static/public/228/$zipName"
val outputDir = "data/sms-spam-raw"

[36mzipName[39m: [32mString[39m = [32m"sms+spam+collection.zip"[39m
[36mdatasetUrl[39m: [32mString[39m = [32m"https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"[39m
[36moutputDir[39m: [32mString[39m = [32m"data/sms-spam-raw"[39m

In [3]:
Magic.!("curl", "--create-dirs", "-O", "--output-dir", outputDir, datasetUrl)

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 65461    0 65461    0     0  69798      0 --:--:-- --:--:-- --:--:-- 69787
100  198k    0  198k    0     0   175k      0 --:--:--  0:00:01 --:--:--  175k


In [4]:
Magic.!("unzip", "-o", s"$outputDir/$zipName", "-d", outputDir)

Archive:  data/sms-spam-raw/sms+spam+collection.zip
  inflating: data/sms-spam-raw/SMSSpamCollection  
  inflating: data/sms-spam-raw/readme  


In [5]:
import scala.io.Source

val datasetRaw = Source.fromFile(s"$outputDir/SMSSpamCollection").mkString

case class SmsSpamRecord(
  text: String,
  isSpam: Boolean
)

type Dataset = Vector[SmsSpamRecord]

val smsSpamRecords: Dataset = datasetRaw.split("\n").map {
  case s"spam\t$text" => SmsSpamRecord(text, isSpam = true)
  case s"ham\t$text" => SmsSpamRecord(text, isSpam = false)
}.toVector

val (spamRecords, notSpamRecords) = smsSpamRecords.partition(_.isSpam)
println(s"Spam count: ${spamRecords.size}")
println(s"Not spam count: ${notSpamRecords.size}")

Spam count: 747
Not spam count: 4827


[32mimport [39m[36mscala.io.Source[39m
[36mdatasetRaw[39m: [32mString[39m = [32m"""ham	Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
ham	Ok lar... Joking wif u oni...
spam	Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
ham	U dun say so early hor... U c already then say...
ham	Nah I don't think he goes to usf, he lives around here though
spam	FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv
ham	Even my brother is not like to speak with me. They treat me like aids patent.
ham	As per your request 'Melle Melle (Oru Minnaminunginte Nurungu Vettam)' has been set as your callertune for all Callers. Press *9 to copy your friends Callertune
spam	WINNER!! As a valued network customer you have been selected to receivea

In [6]:
import scala.collection.mutable
import scala.util.Random

val balancedDataset: Dataset = {

  def sample(records: Vector[SmsSpamRecord], targetSize: Int): Vector[SmsSpamRecord] = {
    val balancedDatasetSpam = mutable.Map[String, SmsSpamRecord]()
    while (balancedDatasetSpam.size < targetSize) {
      val randomRecord = records(Random.nextInt(records.size))
      if (!balancedDatasetSpam.contains(randomRecord.text))
        balancedDatasetSpam += randomRecord.text -> randomRecord
    }
    balancedDatasetSpam.values.toVector
  }

  if (spamRecords.size < notSpamRecords.size)
    spamRecords ++ sample(notSpamRecords, targetSize = spamRecords.size)
  else
    notSpamRecords ++ sample(spamRecords, targetSize = notSpamRecords.size)
}

[32mimport [39m[36mscala.collection.mutable[39m
[32mimport [39m[36mscala.util.Random[39m
[36mbalancedDataset[39m: [32mDataset[39m = [33mVector[39m(
  [33mSmsSpamRecord[39m(
    text = [32m"Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's"[39m,
    isSpam = [32mtrue[39m
  ),
  [33mSmsSpamRecord[39m(
    text = [32m"FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv"[39m,
    isSpam = [32mtrue[39m
  ),
  [33mSmsSpamRecord[39m(
    text = [32m"WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only."[39m,
    isSpam = [32mtrue[39m
  ),
  [33mSmsSpamRecord[39m(
    text = [32m"Had your mobile 11 months or more? U R entitled to Update to the latest colour m

In [7]:
type Training = Dataset
type Validation = Dataset
type Test = Dataset

def randomSplit(dataset: Vector[SmsSpamRecord], trainingFraction: Double, validationFraction: Double): (Training, Validation, Test) = {
  val shuffledDataset = Random.shuffle(dataset)
  val trainingSize = (shuffledDataset.size * trainingFraction).floor.toInt
  val validationSize = (shuffledDataset.size * validationFraction).floor.toInt

  val (training, remainingRecords) = shuffledDataset.splitAt(trainingSize)
  val (validation, test) = remainingRecords.splitAt(validationSize)
  (training, validation, test)
}

val (training, validation, test) = randomSplit(balancedDataset, trainingFraction = 0.7, validationFraction = 0.1) 

defined [32mtype[39m [36mTraining[39m
defined [32mtype[39m [36mValidation[39m
defined [32mtype[39m [36mTest[39m
defined [32mfunction[39m [36mrandomSplit[39m
[36mtraining[39m: [32mTraining[39m = [33mVector[39m(
  [33mSmsSpamRecord[39m(text = [32m"Ugh just got outta class"[39m, isSpam = [32mfalse[39m),
  [33mSmsSpamRecord[39m(
    text = [32m"Sorry da thangam, very very sorry i am held up with prasad."[39m,
    isSpam = [32mfalse[39m
  ),
  [33mSmsSpamRecord[39m(
    text = [32m"As a valued customer, I am pleased to advise you that following recent review of your Mob No. you are awarded with a £1500 Bonus Prize, call 09066368470"[39m,
    isSpam = [32mtrue[39m
  ),
  [33mSmsSpamRecord[39m(text = [32m"Bugis oso near wat... "[39m, isSpam = [32mfalse[39m),
  [33mSmsSpamRecord[39m(
    text = [32m"Hi I'm sue. I am 20 years old and work as a lapdancer. I love sex. Text me live - I'm i my bedroom now. text SUE to 89555. By TextOperator G2 1DA 15

In [8]:
import $ivy.`com.github.tototoshi::scala-csv:2.0.0`

import scala.util.Using
import com.github.tototoshi.csv.CSVWriter

val textHeader = "Text"
val labelHeader = "Label"

def writeToCsv(path: String, dataset: Dataset): Unit = {
  val headers = Vector(textHeader, labelHeader)

  Using.resource(CSVWriter.open(path)) { writer =>
    val rows = dataset.map {
      case SmsSpamRecord(text, isSpam) => Vector(text, if (isSpam) "1" else "0")
    }
    writer.writeAll(headers +: rows)
  }
}

val trainingCsv = "data/training.csv"
writeToCsv(trainingCsv, training)
val validationCsv = "data/validation.csv"
writeToCsv(validationCsv, validation)
val testCsv = "data/test.csv"
writeToCsv(testCsv, test)

[32mimport [39m[36m$ivy.$[39m
[32mimport [39m[36mscala.util.Using[39m
[32mimport [39m[36mcom.github.tototoshi.csv.CSVWriter[39m
[36mtextHeader[39m: [32mString[39m = [32m"Text"[39m
[36mlabelHeader[39m: [32mString[39m = [32m"Label"[39m
defined [32mfunction[39m [36mwriteToCsv[39m
[36mtrainingCsv[39m: [32mString[39m = [32m"data/training.csv"[39m
[36mvalidationCsv[39m: [32mString[39m = [32m"data/validation.csv"[39m
[36mtestCsv[39m: [32mString[39m = [32m"data/test.csv"[39m

In [9]:
Magic.!("pip", "install", "tiktoken==0.7.*")




[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: pip install --upgrade pip


In [10]:
import $ivy.`dev.scalapy::scalapy-core:0.5.3`

import me.shadaj.scalapy.py
import py.SeqConverters

val tiktoken = py.module("tiktoken")

val tokenizer = tiktoken.get_encoding("gpt2")
val endOfTextToken = "<|endoftext|>"
val encodedEndOfTextToken = tokenizer.encode(endOfTextToken, allowed_special = py.Dynamic.global.set(Seq(endOfTextToken).toPythonProxy))
println(encodedEndOfTextToken)

[50256]


[32mimport [39m[36m$ivy.$[39m
[32mimport [39m[36mme.shadaj.scalapy.py[39m
[32mimport [39m[36mpy.SeqConverters[39m
[36mtiktoken[39m: [32mpy[39m.[32mModule[39m = <module 'tiktoken' from '/usr/local/lib/python3.12/site-packages/tiktoken/__init__.py'>
[36mtokenizer[39m: [32mpy[39m.[32mDynamic[39m = <Encoding 'gpt2'>
[36mendOfTextToken[39m: [32mString[39m = [32m"<|endoftext|>"[39m
[36mencodedEndOfTextToken[39m: [32mpy[39m.[32mDynamic[39m = [50256]

In [11]:
Magic.!("pip", "install", "torch==2.4.*")




[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: pip install --upgrade pip


In [12]:
import com.github.tototoshi.csv.CSVReader
import py.PyQuote

val torch = py.module("torch")

type Tokenizer = py.Dynamic

// Workaround to define a class that inherits from a Python class
py.exec {
  s"""from torch.utils.data import Dataset
     |
     |class SpamDataset(Dataset):
     |  def __init__(self, init):
     |    init(self)
     |
     |  def __getitem__(self, index):
     |    return self.getItem(index)
     |  
     |  def __len__(self):
     |    return self.len()
     |""".stripMargin
}
def SpamDataset(
  csvPath: String,
  tokenizer: Tokenizer,
  maxLength: Option[Int] = None,
  paddingTokenId: Int = 50_256
): py.Dynamic = {
  val smsSpamRecords = Using.resource(CSVReader.open(csvPath)) { csvReader =>
    csvReader.iteratorWithHeaders.map { row =>
      SmsSpamRecord(text = row(textHeader), isSpam = row(labelHeader).toInt > 0)
    }.toVector
  }
  val encodedTexts = {
    val encodedTexts = smsSpamRecords.map(_.text).map(tokenizer.encode(_).as[Seq[Int]].toVector)
    val padToLength = maxLength.getOrElse(encodedTexts.map(_.length).max)
    encodedTexts.map(_.padTo(padToLength, paddingTokenId))
  }
    
  val init = (self: py.Dynamic) => {
    self.maxLength = encodedTexts.head.length
    
    val getItem = (index: Int) => {
      val textTensor = torch.tensor(encodedTexts(index).toPythonProxy, dtype = torch.long)
      val labelTensor = torch.tensor(if (smsSpamRecords(index).isSpam) 1 else 0, dtype = torch.long)
      (textTensor, labelTensor)
    }
    self.getItem = getItem

    val len = () => smsSpamRecords.size
    self.len = len
  }
  py.Dynamic.global.SpamDataset(init)
}

[32mimport [39m[36mcom.github.tototoshi.csv.CSVReader[39m
[32mimport [39m[36mpy.PyQuote[39m
[36mtorch[39m: [32mpy[39m.[32mModule[39m = <module 'torch' from '/usr/local/lib/python3.12/site-packages/torch/__init__.py'>
defined [32mtype[39m [36mTokenizer[39m
defined [32mfunction[39m [36mSpamDataset[39m

In [13]:
val trainingDataset = SpamDataset(trainingCsv, tokenizer)
val validationDataset = SpamDataset(validationCsv, tokenizer, maxLength = Some(trainingDataset.maxLength.as[Int]))
val testDataset = SpamDataset(testCsv, tokenizer, maxLength = Some(trainingDataset.maxLength.as[Int]))

[36mtrainingDataset[39m: [32mpy[39m.[32mDynamic[39m = <SpamDataset object at 0xffff4c3ad280>
[36mvalidationDataset[39m: [32mpy[39m.[32mDynamic[39m = <SpamDataset object at 0xffff3d31c200>
[36mtestDataset[39m: [32mpy[39m.[32mDynamic[39m = <SpamDataset object at 0xffff3d31f5f0>

In [14]:
val batchSize = 8
torch.manual_seed(123)

val trainingDataLoader = torch.utils.data.DataLoader(
  dataset = trainingDataset, 
  batch_size = batchSize,
  shuffle = true,
  num_workers = 0,
  drop_last = true
)
val validationDataLoader = torch.utils.data.DataLoader(
  dataset = validationDataset, 
  batch_size = batchSize,
  num_workers = 0,
  drop_last = false
)
val testDataLoader = torch.utils.data.DataLoader(
  dataset = testDataset, 
  batch_size = batchSize,
  num_workers = 0,
  drop_last = false
)

println(s"${py.Dynamic.global.len(trainingDataLoader)} training batches")
println(s"${py.Dynamic.global.len(validationDataLoader)} validation batches")
println(s"${py.Dynamic.global.len(testDataLoader)} test batches")

130 training batches
19 validation batches
38 test batches


[36mbatchSize[39m: [32mInt[39m = [32m8[39m
[36mres14_1[39m: [32mpy[39m.[32mDynamic[39m = <torch._C.Generator object at 0xffff4ff2c790>
[36mtrainingDataLoader[39m: [32mpy[39m.[32mDynamic[39m = <torch.utils.data.dataloader.DataLoader object at 0xffff3f3f3e90>
[36mvalidationDataLoader[39m: [32mpy[39m.[32mDynamic[39m = <torch.utils.data.dataloader.DataLoader object at 0xffff58175cd0>
[36mtestDataLoader[39m: [32mpy[39m.[32mDynamic[39m = <torch.utils.data.dataloader.DataLoader object at 0xffff3d78bd40>

In [15]:
case class GPTConfig(
  vocabularySize: Int,
  contextLength: Int,
  embeddingDimension: Int,
  attentionHeadsCount: Int,
  layersCount: Int,
  dropoutRate: Double,
  queryKeyValueBias: Boolean
)

defined [32mclass[39m [36mGPTConfig[39m

In [16]:
type TorchTensor = py.Dynamic

// Workaround to define a class that inherits from a Python class
py.exec {
  s"""import torch.nn as nn
     |
     |class MultiHeadAttention(nn.Module):
     |  def __init__(self, init):
     |    super().__init__()
     |    init(self)
     |""".stripMargin
}
def MultiHeadAttention(
  inputDimension: Int,
  outputDimension: Int,
  dropoutProbability: Double,
  contextLength: Int,
  headsCount: Int,
  queryKeyValueBias: Boolean
): py.Dynamic = {
  assert(outputDimension % headsCount == 0, "Output dimension must be a multiple of heads count")
  val headDimension = outputDimension / headsCount
    
  val init = (self: py.Dynamic) => {
    self.weightsQuery = torch.nn.Linear(inputDimension, outputDimension, bias = queryKeyValueBias)
    self.weightsKey = torch.nn.Linear(inputDimension, outputDimension, bias = queryKeyValueBias)
    self.weightsValue = torch.nn.Linear(inputDimension, outputDimension, bias = queryKeyValueBias)
    self.outputProjection = torch.nn.Linear(outputDimension, outputDimension)
    self.dropout = torch.nn.Dropout(dropoutProbability)
    self.register_buffer("mask", torch.triu(torch.ones(contextLength, contextLength), diagonal = 1))
      
    val forward = (batchedInputs: TorchTensor) => {
      val (batchesCount, tokensCount, tokenDimension) = batchedInputs.shape.as[(Int, Int, Int)]
      val queries = self.weightsQuery(batchedInputs)
        .view(batchesCount, tokensCount, headsCount, headDimension)
        .transpose(1, 2)
      val keys = self.weightsKey(batchedInputs)
        .view(batchesCount, tokensCount, headsCount, headDimension)
        .transpose(1, 2)
      val values = self.weightsValue(batchedInputs)
        .view(batchesCount, tokensCount, headsCount, headDimension)
        .transpose(1, 2)
      val attentionScores = py"$queries @ $keys.transpose(2, 3)"
      attentionScores.masked_fill_(py"${self.mask}.bool()[:$tokensCount, :$tokensCount]", -torch.inf)
      val attentionWeights = self.dropout(torch.softmax(py"$attentionScores / $headDimension**0.5", dim = -1))
      self.outputProjection(
        py"$attentionWeights @ $values"
          .transpose(1, 2)
          .reshape(batchesCount, tokensCount, outputDimension)
      )
    }
    self.forward = forward
  }
  py.Dynamic.global.MultiHeadAttention(init)
}

defined [32mtype[39m [36mTorchTensor[39m
defined [32mfunction[39m [36mMultiHeadAttention[39m

In [17]:
// Workaround to define a class that inherits from a Python class
// Because it mostly uses Python operators, it's implemented fully in Python
py.exec {
  s"""import torch
     |import torch.nn as nn
     |
     |class GELU(nn.Module):
     |  def __init__(self):
     |    super().__init__()
     |
     |  def forward(self, inputs):
     |    return 0.5 * inputs * (
     |      1 + torch.tanh(
     |        torch.sqrt(torch.tensor(2.0 / torch.pi)) * (inputs + 0.044715 * torch.pow(inputs, 3))
     |      )
     |    )
     |""".stripMargin
}
def GELU() = py.Dynamic.global.GELU()

defined [32mfunction[39m [36mGELU[39m

In [18]:
// Workaround to define a class that inherits from a Python class
py.exec {
  s"""import torch.nn as nn
     |
     |class FeedForward(nn.Module):
     |  def __init__(self, init):
     |    super().__init__()
     |    init(self)
     |""".stripMargin
}
def FeedForward(
  embeddingDimension: Int
): py.Dynamic = {
  val init = (self: py.Dynamic) => {
    self.layers = torch.nn.Sequential(
      torch.nn.Linear(embeddingDimension, 4 * embeddingDimension),
      GELU(),
      torch.nn.Linear(4 * embeddingDimension, embeddingDimension)
    )
      
    val forward = (inputs: TorchTensor) => self.layers(inputs)
    self.forward = forward
  }
  py.Dynamic.global.FeedForward(init)
}

defined [32mfunction[39m [36mFeedForward[39m

In [19]:
// Workaround to define a class that inherits from a Python class
py.exec {
  s"""import torch.nn as nn
     |
     |class NormalizationLayer(nn.Module):
     |  def __init__(self, init):
     |    super().__init__()
     |    init(self)
     |""".stripMargin
}
def NormalizationLayer(
  embeddingDimension: Int
): py.Dynamic = {
  val epsilon = 1e-5
  val init = (self: py.Dynamic) => {
    self.scale = torch.nn.Parameter(torch.ones(embeddingDimension))
    self.shift = torch.nn.Parameter(torch.zeros(embeddingDimension))
      
    val forward = (inputs: TorchTensor) => {
      val mean = inputs.mean(dim = -1, keepdim = true)
      val variance = inputs.`var`(dim = -1, keepdim = true, unbiased = false)
      val normalizedInputs = py"($inputs - $mean) / torch.sqrt($variance + $epsilon)"
      py"${self.scale} * $normalizedInputs + ${self.shift}"
    }
    self.forward = forward
  }
  py.Dynamic.global.NormalizationLayer(init)
}

defined [32mfunction[39m [36mNormalizationLayer[39m

In [20]:
import scala.util.chaining._

py.exec {
  s"""import torch.nn as nn
     |
     |class TransformerBlock(nn.Module):
     |  def __init__(self, init):
     |    super().__init__()
     |    init(self)
     |""".stripMargin
}
def TransformerBlock(
  config: GPTConfig
): py.Dynamic = {
  val init = (self: py.Dynamic) => {
    self.multiHeadAttention = MultiHeadAttention(
      inputDimension = config.embeddingDimension,
      outputDimension = config.embeddingDimension,
      dropoutProbability = config.dropoutRate,
      contextLength = config.contextLength,
      headsCount = config.attentionHeadsCount,
      queryKeyValueBias = config.queryKeyValueBias
    )
    self.feedForward = FeedForward(config.embeddingDimension)
    self.normalization1 = NormalizationLayer(config.embeddingDimension)
    self.normalization2 = NormalizationLayer(config.embeddingDimension)
    self.dropoutShortcut = torch.nn.Dropout(config.dropoutRate)
    
    val forward = (inputs: TorchTensor) => {
      val shortcut = inputs
      val newShortcut = inputs
        .pipe(self.normalization1(_))
        .pipe(self.multiHeadAttention(_))
        .pipe(self.dropoutShortcut(_))
        .pipe(o => py"$o + $shortcut")
      newShortcut
        .pipe(self.normalization2(_))
        .pipe(self.feedForward(_))
        .pipe(self.dropoutShortcut(_))
        .pipe(o => py"$o + $newShortcut")
    }
    self.forward = forward
  }
  py.Dynamic.global.TransformerBlock(init)
}

[32mimport [39m[36mscala.util.chaining._[39m
defined [32mfunction[39m [36mTransformerBlock[39m

In [21]:
// Workaround to define a class that inherits from a Python class
py.exec {
  s"""import torch.nn as nn
     |
     |class GPTModel(nn.Module):
     |  def __init__(self, init):
     |    super().__init__()
     |    init(self)
     |""".stripMargin
}
type Model = py.Dynamic
def GPTModel(
  config: GPTConfig
): Model = {
  val transformerBlocks = Seq.fill(config.layersCount)(TransformerBlock(config))
  val init = (self: py.Dynamic) => {
    self.tokenEmbeddingLayer = torch.nn.Embedding(config.vocabularySize, config.embeddingDimension)
    self.positionEmbeddingLayer = torch.nn.Embedding(config.contextLength, config.embeddingDimension)
    self.dropoutEmbeddingLayer = torch.nn.Dropout(config.dropoutRate)
    self.transformerBlocksLayer = py"nn.Sequential(*${transformerBlocks.toPythonProxy})"
    self.finalNormalizationLayer = NormalizationLayer(config.embeddingDimension)
    self.outputLayer = torch.nn.Linear(config.embeddingDimension, config.vocabularySize, bias = false)
      
    val forward = (batchedInputs: TorchTensor) => {
      val (_, sequenceLength) = batchedInputs.shape.as[(Int, Int)]
      val tokenEmbeddings = self.tokenEmbeddingLayer(batchedInputs)
      val positionEmbeddings = self.positionEmbeddingLayer(torch.arange(sequenceLength, device = batchedInputs.device))
      py"$tokenEmbeddings + $positionEmbeddings"
        .pipe(self.dropoutEmbeddingLayer(_))
        .pipe(self.transformerBlocksLayer(_))
        .pipe(self.finalNormalizationLayer(_))
        .pipe(self.outputLayer(_))
    }
    self.forward = forward
  }
  py.Dynamic.global.GPTModel(init)
}

defined [32mtype[39m [36mModel[39m
defined [32mfunction[39m [36mGPTModel[39m

In [22]:
val baseUrl = "https://f001.backblazeb2.com/file/LLMs-from-scratch/gpt2/124M" // backup
// val baseUrl = "https://openaipublic.blob.core.windows.net/gpt-2/models/124M" // backup
val hparamsFilename = "hparams.json"
val filenames = List("checkpoint", "encoder.json", hparamsFilename, "model.ckpt.data-00000-of-00001", "model.ckpt.index", "model.ckpt.meta", "vocab.bpe")

val outputDir = "data/openai124M"

[36mbaseUrl[39m: [32mString[39m = [32m"https://f001.backblazeb2.com/file/LLMs-from-scratch/gpt2/124M"[39m
[36mhparamsFilename[39m: [32mString[39m = [32m"hparams.json"[39m
[36mfilenames[39m: [32mList[39m[[32mString[39m] = [33mList[39m(
  [32m"checkpoint"[39m,
  [32m"encoder.json"[39m,
  [32m"hparams.json"[39m,
  [32m"model.ckpt.data-00000-of-00001"[39m,
  [32m"model.ckpt.index"[39m,
  [32m"model.ckpt.meta"[39m,
  [32m"vocab.bpe"[39m
)
[36moutputDir[39m: [32mString[39m = [32m"data/openai124M"[39m

In [23]:
filenames.foreach { filename =>
  println(s"Downloading $filename...")
  Magic.!("curl", "--create-dirs", "-O", "--output-dir", outputDir, s"$baseUrl/$filename")
}

Downloading checkpoint...


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100    77  100    77    0     0    112      0 --:--:-- --:--:-- --:--:--   112


Downloading encoder.json...


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  4 1017k    4 48714    0     0  48908      0  0:00:21 --:--:--  0:00:21 48909
100 1017k  100 1017k    0     0   565k      0  0:00:01  0:00:01 --:--:--  566k


Downloading hparams.json...


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100    90  100    90    0     0    138      0 --:--:-- --:--:-- --:--:--   138


Downloading model.ckpt.data-00000-of-00001...


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0  474M    0  439k    0     0   292k      0  0:27:41  0:00:01  0:27:40  292k
  2  474M    2 10.0M    0     0  4095k      0  0:01:58  0:00:02  0:01:56 4095k
  5  474M    5 25.5M    0     0  7437k      0  0:01:05  0:00:03  0:01:02 7435k
  8  474M    8 41.6M    0     0  9405k      0  0:00:51  0:00:04  0:00:47 9404k
 12  474M   12 58.1M    0     0  10.4M      0  0:00:45  0:00:05  0:00:40 11.6M
 15  474M   15 74.0M    0     0  11.2M      0  0:00:42  0:00:06  0:00:36 14.5M
 19  474M   19 90.5M    0     0  11.9M      0  0:00:39  0:00:07  0:00:32 15.9M
 22  474M   22  106M    0     0  12.4M      0  0:00:38  0:00:08  0:00:30 16.1M
 25  474M   25  121M    0     0  12.7M      0  0:00

Downloading model.ckpt.index...


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100  5215  100  5215    0     0   8799      0 --:--:-- --:--:-- --:--:--  8809


Downloading model.ckpt.meta...


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
 19  460k   19 89664    0     0  73720      0  0:00:06  0:00:01  0:00:05 73676
100  460k  100  460k    0     0   270k      0  0:00:01  0:00:01 --:--:--  270k


Downloading vocab.bpe...


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
 78  445k   78  351k    0     0   235k      0  0:00:01  0:00:01 --:--:--  235k
100  445k  100  445k    0     0   279k      0  0:00:01  0:00:01 --:--:--  279k


In [24]:
Magic.!("pip", "install", "tensorflow==2.16.*")




[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: pip install --upgrade pip


In [25]:
import $ivy.`com.lihaoyi::ujson:4.1.0`

import scala.io.Source

val hparamsMap = ujson.read(Source.fromFile(s"$outputDir/$hparamsFilename").mkString)

val gptConfig = GPTConfig(
  vocabularySize = hparamsMap("n_vocab").num.toInt,
  contextLength = hparamsMap("n_ctx").num.toInt,
  embeddingDimension = hparamsMap("n_embd").num.toInt,
  attentionHeadsCount = hparamsMap("n_head").num.toInt,
  layersCount = hparamsMap("n_layer").num.toInt,
  dropoutRate = 0.1,
  queryKeyValueBias = true
)

[32mimport [39m[36m$ivy.$[39m
[32mimport [39m[36mscala.io.Source[39m
[36mhparamsMap[39m: [32mujson[39m.[32mValue[39m.[32mValue[39m = [33mObj[39m(
  value = [33mMap[39m(
    [32m"n_vocab"[39m -> [33mNum[39m(value = [32m50257.0[39m),
    [32m"n_ctx"[39m -> [33mNum[39m(value = [32m1024.0[39m),
    [32m"n_embd"[39m -> [33mNum[39m(value = [32m768.0[39m),
    [32m"n_head"[39m -> [33mNum[39m(value = [32m12.0[39m),
    [32m"n_layer"[39m -> [33mNum[39m(value = [32m12.0[39m)
  )
)
[36mgptConfig[39m: [32mGPTConfig[39m = [33mGPTConfig[39m(
  vocabularySize = [32m50257[39m,
  contextLength = [32m1024[39m,
  embeddingDimension = [32m768[39m,
  attentionHeadsCount = [32m12[39m,
  layersCount = [32m12[39m,
  dropoutRate = [32m0.1[39m,
  queryKeyValueBias = [32mtrue[39m
)

In [26]:
val tf = py.module("tensorflow")
val np = py.module("numpy")

[36mtf[39m: [32mpy[39m.[32mModule[39m = <module 'tensorflow' from '/usr/local/lib/python3.12/site-packages/tensorflow/__init__.py'>
[36mnp[39m: [32mpy[39m.[32mModule[39m = <module 'numpy' from '/usr/local/lib/python3.12/site-packages/numpy/__init__.py'>

In [27]:
val checkpoint = tf.train.latest_checkpoint(outputDir)
val variableNames = tf.train.list_variables(checkpoint).as[Seq[(String, Seq[Int])]].map { 
  case (variableName, _) => variableName 
}.toList
variableNames.sorted.foreach(println)

model/h0/attn/c_attn/b
model/h0/attn/c_attn/w
model/h0/attn/c_proj/b
model/h0/attn/c_proj/w
model/h0/ln_1/b
model/h0/ln_1/g
model/h0/ln_2/b
model/h0/ln_2/g
model/h0/mlp/c_fc/b
model/h0/mlp/c_fc/w
model/h0/mlp/c_proj/b
model/h0/mlp/c_proj/w
model/h1/attn/c_attn/b
model/h1/attn/c_attn/w
model/h1/attn/c_proj/b
model/h1/attn/c_proj/w
model/h1/ln_1/b
model/h1/ln_1/g
model/h1/ln_2/b
model/h1/ln_2/g
model/h1/mlp/c_fc/b
model/h1/mlp/c_fc/w
model/h1/mlp/c_proj/b
model/h1/mlp/c_proj/w
model/h10/attn/c_attn/b
model/h10/attn/c_attn/w
model/h10/attn/c_proj/b
model/h10/attn/c_proj/w
model/h10/ln_1/b
model/h10/ln_1/g
model/h10/ln_2/b
model/h10/ln_2/g
model/h10/mlp/c_fc/b
model/h10/mlp/c_fc/w
model/h10/mlp/c_proj/b
model/h10/mlp/c_proj/w
model/h11/attn/c_attn/b
model/h11/attn/c_attn/w
model/h11/attn/c_proj/b
model/h11/attn/c_proj/w
model/h11/ln_1/b
model/h11/ln_1/g
model/h11/ln_2/b
model/h11/ln_2/g
model/h11/mlp/c_fc/b
model/h11/mlp/c_fc/w
model/h11/mlp/c_proj/b
model/h11/mlp/c_proj/w
model/h2/attn/c_

[36mcheckpoint[39m: [32mpy[39m.[32mDynamic[39m = data/openai124M/model.ckpt
[36mvariableNames[39m: [32mList[39m[[32mString[39m] = [33mList[39m(
  [32m"model/h0/attn/c_attn/b"[39m,
  [32m"model/h0/attn/c_attn/w"[39m,
  [32m"model/h0/attn/c_proj/b"[39m,
  [32m"model/h0/attn/c_proj/w"[39m,
  [32m"model/h0/ln_1/b"[39m,
  [32m"model/h0/ln_1/g"[39m,
  [32m"model/h0/ln_2/b"[39m,
  [32m"model/h0/ln_2/g"[39m,
  [32m"model/h0/mlp/c_fc/b"[39m,
  [32m"model/h0/mlp/c_fc/w"[39m,
  [32m"model/h0/mlp/c_proj/b"[39m,
  [32m"model/h0/mlp/c_proj/w"[39m,
  [32m"model/h1/attn/c_attn/b"[39m,
  [32m"model/h1/attn/c_attn/w"[39m,
  [32m"model/h1/attn/c_proj/b"[39m,
  [32m"model/h1/attn/c_proj/w"[39m,
  [32m"model/h1/ln_1/b"[39m,
  [32m"model/h1/ln_1/g"[39m,
  [32m"model/h1/ln_2/b"[39m,
  [32m"model/h1/ln_2/g"[39m,
  [32m"model/h1/mlp/c_fc/b"[39m,
  [32m"model/h1/mlp/c_fc/w"[39m,
  [32m"model/h1/mlp/c_proj/b"[39m,
  [32m"model/h1/mlp/c_proj/w"[39m,
  

In [28]:
type NpArray = py.Dynamic

def toTorchParameter(npArray: NpArray) =
  torch.nn.Parameter(torch.tensor(npArray))

def loadModelWeights(model: Model): Unit =
  variableNames.foreach { variableName =>
    val variableValue = np.squeeze(tf.train.load_variable(checkpoint, variableName))
    variableName.split("/").drop(1).toList match {
      case s"h$transformerBlockIndexString" :: tail =>
        val transformerBlockIndex = transformerBlockIndexString.toInt
        tail match {
          case "attn" :: tail =>
            val multiHeadAttention = model.transformerBlocksLayer.bracketAccess(transformerBlockIndex).multiHeadAttention
            tail match {
              case "c_attn" :: tail =>
                val Seq(queryVariableValue, keyVariableValue, valueVariableValue) = np.split(variableValue, 3, axis = -1).as[Seq[NpArray]]
                tail match {
                  case "b" :: _ => 
                    multiHeadAttention.weightsQuery.bias = toTorchParameter(queryVariableValue)
                    multiHeadAttention.weightsKey.bias = toTorchParameter(keyVariableValue)
                    multiHeadAttention.weightsValue.bias = toTorchParameter(valueVariableValue)
                  case "w" :: _ => 
                    multiHeadAttention.weightsQuery.weight = toTorchParameter(queryVariableValue.T)
                    multiHeadAttention.weightsKey.weight = toTorchParameter(keyVariableValue.T)
                    multiHeadAttention.weightsValue.weight = toTorchParameter(valueVariableValue.T)
                }
              case "c_proj" :: tail =>
                tail match {
                  case "b" :: _ => multiHeadAttention.outputProjection.bias = toTorchParameter(variableValue)
                  case "w" :: _ => multiHeadAttention.outputProjection.weight = toTorchParameter(variableValue.T)
                }
            }
          case "ln_1" :: tail =>
            val normalization1 = model.transformerBlocksLayer.bracketAccess(transformerBlockIndex).normalization1
            val torchParameter = toTorchParameter(variableValue)
            tail match {
              case "b" :: _ => normalization1.shift = torchParameter
              case "g" :: _ => normalization1.scale = torchParameter
            }
          case "ln_2" :: tail =>
            val normalization2 = model.transformerBlocksLayer.bracketAccess(transformerBlockIndex).normalization2
            val torchParameter = toTorchParameter(variableValue)
            tail match {
              case "b" :: _ => normalization2.shift = torchParameter
              case "g" :: _ => normalization2.scale = torchParameter
            }
          case "mlp" :: tail =>
            val feedForward = model.transformerBlocksLayer.bracketAccess(transformerBlockIndex).feedForward
            tail match {
              case "c_fc" :: tail =>
                val layer0 = feedForward.layers.bracketAccess(0)
                tail match {
                  case "b" :: _ => layer0.bias = toTorchParameter(variableValue)
                  case "w" :: _ => layer0.weight = toTorchParameter(variableValue.T)
                }
              case "c_proj" :: tail =>
                val layer2 = feedForward.layers.bracketAccess(2)
                tail match {
                  case "b" :: _ => layer2.bias = toTorchParameter(variableValue)
                  case "w" :: _ => layer2.weight = toTorchParameter(variableValue.T)
                }
            }
        }
      case "ln_f" :: tail =>
        val finalNormalizationLayer = model.finalNormalizationLayer
        val torchParameter = toTorchParameter(variableValue)
        tail match {
          case "b" :: _ => finalNormalizationLayer.shift = torchParameter
          case "g" :: _ => finalNormalizationLayer.scale = torchParameter
        }
      case "wpe" :: _ => model.positionEmbeddingLayer.weight = toTorchParameter(variableValue)
      case "wte" :: _ => 
        val torchParameter = toTorchParameter(variableValue)
        model.tokenEmbeddingLayer.weight = torchParameter
        model.outputLayer.weight = torchParameter
    }
  }

cmd28.sc:18: match may not be exhaustive.
It would fail on the following inputs: List((x: String forSome x not in ("b", "w"))), Nil
                tail match {
                ^
cmd28.sc:29: match may not be exhaustive.
It would fail on the following inputs: List((x: String forSome x not in ("b", "w"))), Nil
                tail match {
                ^
cmd28.sc:15: match may not be exhaustive.
It would fail on the following inputs: List((x: String forSome x not in ("c_attn", "c_proj"))), Nil
            tail match {
            ^
cmd28.sc:37: match may not be exhaustive.
It would fail on the following inputs: List((x: String forSome x not in ("b", "g"))), Nil
            tail match {
            ^
cmd28.sc:44: match may not be exhaustive.
It would fail on the following inputs: List((x: String forSome x not in ("b", "g"))), Nil
            tail match {
            ^
cmd28.sc:53: match may not be exhaustive.
It would fail on the following inputs: List((x: String forSome x not in ("b",

defined [32mtype[39m [36mNpArray[39m
defined [32mfunction[39m [36mtoTorchParameter[39m
defined [32mfunction[39m [36mloadModelWeights[39m

In [29]:
val model = GPTModel(gptConfig)
loadModelWeights(model)
val device = torch.device(if (torch.cuda.is_available().as[Boolean]) "cuda" else "cpu")
model.to(device)
model.eval()

[36mmodel[39m: [32mModel[39m = GPTModel(
  (tokenEmbeddingLayer): Embedding(50257, 768)
  (positionEmbeddingLayer): Embedding(1024, 768)
  (dropoutEmbeddingLayer): Dropout(p=0.1, inplace=False)
  (transformerBlocksLayer): Sequential(
    (0): TransformerBlock(
      (multiHeadAttention): MultiHeadAttention(
        (weightsQuery): Linear(in_features=768, out_features=768, bias=True)
        (weightsKey): Linear(in_features=768, out_features=768, bias=True)
        (weightsValue): Linear(in_features=768, out_features=768, bias=True)
        (outputProjection): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (feedForward): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (normalization1): NormalizationLayer()
      (normalization2): Normaliza

In [30]:
def textToTokenIds(
  text: String, 
  tokenizer: Tokenizer
): TorchTensor = {
  val allowedSpecial = py.Dynamic.global.set(Seq("<|endoftext|>").toPythonProxy)
  val encodedText = tokenizer.encode(text, allowed_special = allowedSpecial)
  torch.tensor(encodedText).unsqueeze(0)
}
    
def tokenIdsToText(
  tokenIds: TorchTensor, 
  tokenizer: Tokenizer
): String =
  tokenizer.decode(tokenIds.squeeze(0).tolist()).as[String]

def generateTextSimple(
  model: Model,
  maxNewTokens: Int,
  contextLength: Int
)(
  encodedInput: TorchTensor
): TorchTensor =
  LazyList.iterate(encodedInput) { currentEncodedOutput =>
    val croppedInput = py"$currentEncodedOutput[:, -$contextLength:]"
    val logits = py.`with`(torch.no_grad()) { _ =>
      model(croppedInput)
    }
    py"$logits[:, -1, :]"
      .pipe(torch.softmax(_, dim = -1))
      .pipe(torch.argmax(_, dim = -1, keepdim = true))
      .pipe(nextEncodedOutput => torch.cat((currentEncodedOutput, nextEncodedOutput), dim = 1))
  }.drop(maxNewTokens).head

defined [32mfunction[39m [36mtextToTokenIds[39m
defined [32mfunction[39m [36mtokenIdsToText[39m
defined [32mfunction[39m [36mgenerateTextSimple[39m

In [31]:
val exampleText = "Every effort moves you"
val encodedText = textToTokenIds(exampleText, tokenizer)
val encodedTextOutput = generateTextSimple(model, maxNewTokens = 15, contextLength = gptConfig.contextLength)(encodedText)
val decodedTextOutput = tokenIdsToText(encodedTextOutput, tokenizer)
println(decodedTextOutput)

Every effort moves you forward.

The first step is to understand the importance of your work


[36mexampleText[39m: [32mString[39m = [32m"Every effort moves you"[39m
[36mencodedText[39m: [32mTorchTensor[39m = tensor([[6109, 3626, 6100,  345]])
[36mencodedTextOutput[39m: [32mTorchTensor[39m = tensor([[6109, 3626, 6100,  345, 2651,   13,  198,  198,  464,  717, 2239,  318,
          284, 1833,  262, 6817,  286,  534,  670]])
[36mdecodedTextOutput[39m: [32mString[39m = [32m"""Every effort moves you forward.

The first step is to understand the importance of your work"""[39m

In [32]:
val spamClassificationPrompt = "Is the following text spam? Answer with 'yes' or 'no': 'You are a winner you have been specially selected to receive $1000 cash or a $2000 award.'"
val spamClassificationPromptAnswer = 
  textToTokenIds(spamClassificationPrompt, tokenizer)
    .pipe(generateTextSimple(model, maxNewTokens = 23, contextLength = gptConfig.contextLength))
    .pipe(tokenIdsToText(_, tokenizer))
println(spamClassificationPromptAnswer)

Is the following text spam? Answer with 'yes' or 'no': 'You are a winner you have been specially selected to receive $1000 cash or a $2000 award.'

The following text spam? Answer with 'yes' or 'no': 'You are a winner you have


[36mspamClassificationPrompt[39m: [32mString[39m = [32m"Is the following text spam? Answer with 'yes' or 'no': 'You are a winner you have been specially selected to receive $1000 cash or a $2000 award.'"[39m
[36mspamClassificationPromptAnswer[39m: [32mString[39m = [32m"""Is the following text spam? Answer with 'yes' or 'no': 'You are a winner you have been specially selected to receive $1000 cash or a $2000 award.'

The following text spam? Answer with 'yes' or 'no': 'You are a winner you have"""[39m

In [33]:
import scala.annotation.tailrec

def foreachPy(iterable: py.Dynamic)(f: py.Dynamic => Unit): Unit = {
  val iterator = py"iter($iterable)"

  @tailrec
  def loop(): Unit = {
    val currentValue = py"next($iterator, None)"
    if (currentValue != py.Dynamic.global.None) {
      f(currentValue)
      loop()
    }
  }

  loop()
}

[32mimport [39m[36mscala.annotation.tailrec[39m
defined [32mfunction[39m [36mforeachPy[39m

In [34]:
foreachPy(model.parameters()) { parameter =>
  parameter.requires_grad = false
}

torch.manual_seed(123)
val classesCount = 2
model.outputLayer = torch.nn.Linear(
  in_features = gptConfig.embeddingDimension,
  out_features = classesCount
)

foreachPy(model.transformerBlocksLayer.bracketAccess(-1).parameters()) { parameter =>
  parameter.requires_grad = true
}

foreachPy(model.finalNormalizationLayer.parameters()) { parameter =>
  parameter.requires_grad = true
}

[36mres34_1[39m: [32mpy[39m.[32mDynamic[39m = <torch._C.Generator object at 0xffff4ff2c790>
[36mclassesCount[39m: [32mInt[39m = [32m2[39m

In [35]:
type Device = py.Dynamic
type DataLoader = py.Dynamic

def calculateDataLoaderAccuracy(
  model: Model,
  device: Device
)(
  dataLoader: DataLoader,
  batchesCountOpt: Option[Int] = None
): Double = { 
  val batchesCount = batchesCountOpt match {
    case Some(batchesCount) => batchesCount
    case None => py"len($dataLoader)".as[Int]
  }
  assert(batchesCount > 0, "There were no batches to process")
  var correctPredictions = 0
  var examplesSeen = 0
  var currentBatchIndex = 0
  foreachPy(dataLoader) { currentBatch =>
    if (currentBatchIndex < batchesCount) 
      py.local {
        val Seq(inputBatch, targetBatch) = currentBatch.as[Seq[TorchTensor]]
        val logits = py.`with`(torch.no_grad()) { _ =>
          model(inputBatch.to(device))
        }
        val predictedClasses = torch.argmax(py"$logits[:, -1, :]", dim = -1)
        examplesSeen += predictedClasses.shape.bracketAccess(0).as[Int]
        correctPredictions += py"$predictedClasses == $targetBatch".sum().item().as[Int]
      }
    currentBatchIndex += 1
  }
  correctPredictions.toDouble / examplesSeen
}

defined [32mtype[39m [36mDevice[39m
defined [32mtype[39m [36mDataLoader[39m
defined [32mfunction[39m [36mcalculateDataLoaderAccuracy[39m

In [36]:
torch.manual_seed(123)

val trainingAccuracy = calculateDataLoaderAccuracy(model, device)(trainingDataLoader, batchesCountOpt = Some(10))
println(f"Training accuracy: ${trainingAccuracy * 100}%.2f%%")
val validationAccuracy = calculateDataLoaderAccuracy(model, device)(validationDataLoader, batchesCountOpt = Some(10))
println(f"Validation accuracy: ${validationAccuracy * 100}%02f%%")
val testAccuracy = calculateDataLoaderAccuracy(model, device)(testDataLoader, batchesCountOpt = Some(10))
println(f"Test accuracy: ${testAccuracy * 100}%02f%%")

Training accuracy: 66.25%
Validation accuracy: 47.500000%
Test accuracy: 42.500000%


[36mres36_0[39m: [32mpy[39m.[32mDynamic[39m = <torch._C.Generator object at 0xffff4ff2c790>
[36mtrainingAccuracy[39m: [32mDouble[39m = [32m0.6625[39m
[36mvalidationAccuracy[39m: [32mDouble[39m = [32m0.475[39m
[36mtestAccuracy[39m: [32mDouble[39m = [32m0.425[39m

In [37]:
def calculateBatchLoss(
  model: Model,
  device: Device
)(
  inputBatch: TorchTensor,
  targetBatch: TorchTensor
): TorchTensor = {
  val logits = model(inputBatch.to(device))
  torch.nn.functional.cross_entropy(py"$logits[:, -1, :]", targetBatch)
}

def calculateDataLoaderLoss(
  model: Model,
  device: Device
)(
  dataLoader: DataLoader,
  batchesCountOpt: Option[Int] = None
): Double = { 
  val batchesCount = batchesCountOpt match {
    case Some(batchesCount) => batchesCount
    case None => py"len($dataLoader)".as[Int]
  }
  assert(batchesCount > 0, "There were no batches to process")
  var totalLoss = 0.0
  var currentBatchIndex = 0
  foreachPy(dataLoader) { currentBatch =>
    if (currentBatchIndex < batchesCount) 
      py.local {
        val Seq(inputBatch, targetBatch) = currentBatch.as[Seq[TorchTensor]]
        totalLoss += calculateBatchLoss(model, device)(inputBatch, targetBatch).item().as[Double]
      }
    currentBatchIndex += 1
  }
  totalLoss / batchesCount
}

defined [32mfunction[39m [36mcalculateBatchLoss[39m
defined [32mfunction[39m [36mcalculateDataLoaderLoss[39m

In [38]:
py.`with`(torch.no_grad()) { _ =>
  val trainingLoss = calculateDataLoaderLoss(model, device)(trainingDataLoader, batchesCountOpt = Some(5))
  println(s"Training loss: $trainingLoss")
  val validationLoss = calculateDataLoaderLoss(model, device)(validationDataLoader, batchesCountOpt = Some(5))
  println(s"Validation loss: $validationLoss")
  val testLoss = calculateDataLoaderLoss(model, device)(testDataLoader, batchesCountOpt = Some(5))
  println(s"Test loss: $testLoss")
}

Training loss: 1.2273894667625427
Validation loss: 2.913743782043457
Test loss: 2.6839645385742186
