From 06978476665b6e17864cbd175c3dd9ab9fe73b5c Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Thu, 11 Jan 2024 18:12:10 +1100 Subject: [PATCH] fix: don't include generic types in variants --- bin/hermit.hcl | 3 +++ cmd/go-check-sumtype/main.go | 4 +++- decl.go | 1 + def.go | 16 ++++++++++++++++ scripts/go-check-sumtype | 7 +++++++ testdata/sum.go | 23 +++++++++++++++++++++++ 6 files changed, 53 insertions(+), 1 deletion(-) create mode 100755 scripts/go-check-sumtype create mode 100644 testdata/sum.go diff --git a/bin/hermit.hcl b/bin/hermit.hcl index e69de29..6084415 100644 --- a/bin/hermit.hcl +++ b/bin/hermit.hcl @@ -0,0 +1,3 @@ +env = { + "PATH": "${HERMIT_ENV}/scripts:${PATH}", +} diff --git a/cmd/go-check-sumtype/main.go b/cmd/go-check-sumtype/main.go index 6ef0878..67df985 100644 --- a/cmd/go-check-sumtype/main.go +++ b/cmd/go-check-sumtype/main.go @@ -1,6 +1,7 @@ package main import ( + "flag" "log" "os" "strings" @@ -11,7 +12,8 @@ import ( func main() { log.SetFlags(0) - if len(os.Args) < 2 { + flag.Parse() + if len(flag.Args()) < 1 { log.Fatalf("Usage: sumtype \n") } args := os.Args[1:] diff --git a/decl.go b/decl.go index ea2cd06..9dec9ee 100644 --- a/decl.go +++ b/decl.go @@ -57,6 +57,7 @@ func findSumTypeDecls(pkgs []*packages.Package) ([]sumTypeDecl, error) { } pos = pkg.Fset.Position(tspec.Pos()) decl := sumTypeDecl{Package: pkg, TypeName: tspec.Name.Name, Pos: pos} + debugf("found sum type decl: %s.%s", decl.Package.PkgPath, decl.TypeName) decls = append(decls, decl) break } diff --git a/def.go b/def.go index 811b98f..24729ac 100644 --- a/def.go +++ b/def.go @@ -1,11 +1,21 @@ package gochecksumtype import ( + "flag" "fmt" "go/token" "go/types" + "log" ) +var debug = flag.Bool("debug", false, "enable debug logging") + +func debugf(format string, args ...interface{}) { + if *debug { + log.Printf(format, args...) + } +} + // Error as returned by Run() type Error interface { error @@ -107,6 +117,7 @@ func newSumTypeDef(pkg *types.Package, decl sumTypeDecl) (*sumTypeDef, error) { Decl: decl, Ty: iface, } + debugf("searching for variants of %s.%s\n", pkg.Path(), decl.TypeName) for _, name := range pkg.Scope().Names() { obj, ok := pkg.Scope().Lookup(name).(*types.TypeName) if !ok { @@ -116,7 +127,12 @@ func newSumTypeDef(pkg *types.Package, decl sumTypeDecl) (*sumTypeDef, error) { if types.Identical(ty.Underlying(), iface) { continue } + // Skip generic types. + if named, ok := ty.(*types.Named); ok && named.TypeParams() != nil { + continue + } if types.Implements(ty, iface) || types.Implements(types.NewPointer(ty), iface) { + debugf(" found variant: %s.%s\n", pkg.Path(), obj.Name()) def.Variants = append(def.Variants, obj) } } diff --git a/scripts/go-check-sumtype b/scripts/go-check-sumtype new file mode 100755 index 0000000..096300e --- /dev/null +++ b/scripts/go-check-sumtype @@ -0,0 +1,7 @@ +#!/bin/bash +set -euo pipefail +basedir="$(dirname "$0")/.." +name="$(basename "$0")" +dest="${basedir}/build/devel" +mkdir -p "$dest" +(cd "${basedir}" && ./bin/go build -ldflags="-s -w -buildid=" -o "$dest/${name}" "./cmd/${name}") && exec "$dest/${name}" "$@" diff --git a/testdata/sum.go b/testdata/sum.go new file mode 100644 index 0000000..8c993cf --- /dev/null +++ b/testdata/sum.go @@ -0,0 +1,23 @@ +package testdata + +//sumtype:decl +type Sum interface{ sum() } + +type A struct{} + +func (A) sum() {} + +type B struct{} + +func (B) sum() {} + +type C[T any] struct{} + +func (C[T]) sum() {} + +func SumSwitch(x Sum) { + switch x.(type) { + case A: + case B: + } +}